diff --git a/faiss/utils/distances_simd.cpp b/faiss/utils/distances_simd.cpp
index 39195d62a5..66bdbcf17a 100644
--- a/faiss/utils/distances_simd.cpp
+++ b/faiss/utils/distances_simd.cpp
@@ -247,6 +247,33 @@ static inline __m128 masked_read(int d, const float* x) {
 
 namespace {
 
+/// helper function
+inline float horizontal_sum(const __m128 v) {
+    // say, v is [x0, x1, x2, x3]
+
+    // v0 is [x2, x3, ..., ...]
+    const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2));
+    // v1 is [x0 + x2, x1 + x3, ..., ...]
+    const __m128 v1 = _mm_add_ps(v, v0);
+    // v2 is [x1 + x3, ..., ..., ...]
+    __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
+    // v3 is [x0 + x1 + x2 + x3, ..., ..., ...]
+    const __m128 v3 = _mm_add_ps(v1, v2);
+    // return v3[0]
+    return _mm_cvtss_f32(v3);
+}
+
+#ifdef __AVX2__
+/// helper function for AVX2
+inline float horizontal_sum(const __m256 v) {
+    // add high and low parts
+    const __m128 v0 =
+            _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
+    // perform horizontal sum on v0
+    return horizontal_sum(v0);
+}
+#endif
+
 /// Function that does a component-wise operation between x and y
 /// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
 /// functions below
@@ -260,6 +287,13 @@ struct ElementOpL2 {
         __m128 tmp = _mm_sub_ps(x, y);
         return _mm_mul_ps(tmp, tmp);
     }
+
+#ifdef __AVX2__
+    static __m256 op(__m256 x, __m256 y) {
+        __m256 tmp = _mm256_sub_ps(x, y);
+        return _mm256_mul_ps(tmp, tmp);
+    }
+#endif
 };
 
 /// Function that does a component-wise operation between x and y
@@ -272,6 +306,12 @@ struct ElementOpIP {
     static __m128 op(__m128 x, __m128 y) {
         return _mm_mul_ps(x, y);
     }
+
+#ifdef __AVX2__
+    static __m256 op(__m256 x, __m256 y) {
+        return _mm256_mul_ps(x, y);
+    }
+#endif
 };
 
 template <class ElementOp>
@@ -314,6 +354,134 @@ void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
     }
 }
 
+#ifdef __AVX2__
+
+template <>
+void fvec_op_ny_D2<ElementOpIP>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D2-vectors per loop.
+        _mm_prefetch(y, _MM_HINT_T0);
+        _mm_prefetch(y + 16, _MM_HINT_T0);
+
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            _mm_prefetch(y + 32, _MM_HINT_T0);
+
+            // load 8x2 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m256 v0;
+            __m256 v1;
+
+            transpose_8x2(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    v0,
+                    v1);
+
+            // compute distances
+            __m256 distances = _mm256_mul_ps(m0, v0);
+            distances = _mm256_fmadd_ps(m1, v1, distances);
+
+            // store
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 16;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        float x0 = x[0];
+        float x1 = x[1];
+
+        for (; i < ny; i++) {
+            float distance = x0 * y[0] + x1 * y[1];
+            y += 2;
+            dis[i] = distance;
+        }
+    }
+}
+
+template <>
+void fvec_op_ny_D2<ElementOpL2>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D2-vectors per loop.
+        _mm_prefetch(y, _MM_HINT_T0);
+        _mm_prefetch(y + 16, _MM_HINT_T0);
+
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            _mm_prefetch(y + 32, _MM_HINT_T0);
+
+            // load 8x2 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
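+            // note: transpose_8x2 is expected to gather the first components
+            // of the 8 loaded vectors into v0 and the second components into
+            // v1, one vector per lane (its definition is not in this diff).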
+
+            __m256 v0;
+            __m256 v1;
+
+            transpose_8x2(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    v0,
+                    v1);
+
+            // compute differences
+            const __m256 d0 = _mm256_sub_ps(m0, v0);
+            const __m256 d1 = _mm256_sub_ps(m1, v1);
+
+            // compute squares of differences
+            __m256 distances = _mm256_mul_ps(d0, d0);
+            distances = _mm256_fmadd_ps(d1, d1, distances);
+
+            // store
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 16;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        float x0 = x[0];
+        float x1 = x[1];
+
+        for (; i < ny; i++) {
+            float sub0 = x0 - y[0];
+            float sub1 = x1 - y[1];
+            float distance = sub0 * sub0 + sub1 * sub1;
+
+            y += 2;
+            dis[i] = distance;
+        }
+    }
+}
+
+#endif
+
 template <class ElementOp>
 void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
     __m128 x0 = _mm_loadu_ps(x);
@@ -321,17 +489,12 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
     for (size_t i = 0; i < ny; i++) {
         __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
         y += 4;
-        accu = _mm_hadd_ps(accu, accu);
-        accu = _mm_hadd_ps(accu, accu);
-        dis[i] = _mm_cvtss_f32(accu);
+        dis[i] = horizontal_sum(accu);
     }
 }
 
 #ifdef __AVX2__
 
-// Specialized versions for AVX2 for any CPUs that support gather/scatter.
-// Todo: implement fvec_op_ny_Dxxx in the same way.
-
 template <>
 void fvec_op_ny_D4<ElementOpIP>(
         float* dis,
@@ -343,16 +506,9 @@ void fvec_op_ny_D4<ElementOpIP>(
 
     if (ny8 > 0) {
         // process 8 D4-vectors per loop.
-        _mm_prefetch(y, _MM_HINT_NTA);
-        _mm_prefetch(y + 16, _MM_HINT_NTA);
-
-        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
         const __m256 m0 = _mm256_set1_ps(x[0]);
-        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
         const __m256 m1 = _mm256_set1_ps(x[1]);
-        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
         const __m256 m2 = _mm256_set1_ps(x[2]);
-        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
         const __m256 m3 = _mm256_set1_ps(x[3]);
 
         for (i = 0; i < ny8 * 8; i += 8) {
@@ -395,9 +551,7 @@ void fvec_op_ny_D4<ElementOpIP>(
         for (; i < ny; i++) {
             __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
             y += 4;
-            accu = _mm_hadd_ps(accu, accu);
-            accu = _mm_hadd_ps(accu, accu);
-            dis[i] = _mm_cvtss_f32(accu);
+            dis[i] = horizontal_sum(accu);
         }
     }
 }
@@ -413,16 +567,9 @@ void fvec_op_ny_D4<ElementOpL2>(
 
     if (ny8 > 0) {
         // process 8 D4-vectors per loop.
-        _mm_prefetch(y, _MM_HINT_NTA);
-        _mm_prefetch(y + 16, _MM_HINT_NTA);
-
-        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
         const __m256 m0 = _mm256_set1_ps(x[0]);
-        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
         const __m256 m1 = _mm256_set1_ps(x[1]);
-        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
         const __m256 m2 = _mm256_set1_ps(x[2]);
-        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
         const __m256 m3 = _mm256_set1_ps(x[3]);
 
         for (i = 0; i < ny8 * 8; i += 8) {
@@ -471,9 +618,7 @@ void fvec_op_ny_D4<ElementOpL2>(
         for (; i < ny; i++) {
             __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
             y += 4;
-            accu = _mm_hadd_ps(accu, accu);
-            accu = _mm_hadd_ps(accu, accu);
-            dis[i] = _mm_cvtss_f32(accu);
+            dis[i] = horizontal_sum(accu);
         }
     }
 }
@@ -496,6 +641,185 @@ void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
     }
 }
 
+#ifdef __AVX2__
+
+template <>
+void fvec_op_ny_D8<ElementOpIP>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D8-vectors per loop.
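+        // note: each query component x[j] is broadcast into its own
+        // register, so a single fmadd per dimension handles 8 vectors.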
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+        const __m256 m4 = _mm256_set1_ps(x[4]);
+        const __m256 m5 = _mm256_set1_ps(x[5]);
+        const __m256 m6 = _mm256_set1_ps(x[6]);
+        const __m256 m7 = _mm256_set1_ps(x[7]);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            // load 8x8 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m256 v0;
+            __m256 v1;
+            __m256 v2;
+            __m256 v3;
+            __m256 v4;
+            __m256 v5;
+            __m256 v6;
+            __m256 v7;
+
+            transpose_8x8(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    _mm256_loadu_ps(y + 2 * 8),
+                    _mm256_loadu_ps(y + 3 * 8),
+                    _mm256_loadu_ps(y + 4 * 8),
+                    _mm256_loadu_ps(y + 5 * 8),
+                    _mm256_loadu_ps(y + 6 * 8),
+                    _mm256_loadu_ps(y + 7 * 8),
+                    v0,
+                    v1,
+                    v2,
+                    v3,
+                    v4,
+                    v5,
+                    v6,
+                    v7);
+
+            // compute distances
+            __m256 distances = _mm256_mul_ps(m0, v0);
+            distances = _mm256_fmadd_ps(m1, v1, distances);
+            distances = _mm256_fmadd_ps(m2, v2, distances);
+            distances = _mm256_fmadd_ps(m3, v3, distances);
+            distances = _mm256_fmadd_ps(m4, v4, distances);
+            distances = _mm256_fmadd_ps(m5, v5, distances);
+            distances = _mm256_fmadd_ps(m6, v6, distances);
+            distances = _mm256_fmadd_ps(m7, v7, distances);
+
+            // store
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 64;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m256 x0 = _mm256_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
+            y += 8;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+template <>
+void fvec_op_ny_D8<ElementOpL2>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D8-vectors per loop.
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+        const __m256 m4 = _mm256_set1_ps(x[4]);
+        const __m256 m5 = _mm256_set1_ps(x[5]);
+        const __m256 m6 = _mm256_set1_ps(x[6]);
+        const __m256 m7 = _mm256_set1_ps(x[7]);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            // load 8x8 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
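+            // (same register-transpose trick as in the IP kernel above.)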
+
+            __m256 v0;
+            __m256 v1;
+            __m256 v2;
+            __m256 v3;
+            __m256 v4;
+            __m256 v5;
+            __m256 v6;
+            __m256 v7;
+
+            transpose_8x8(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    _mm256_loadu_ps(y + 2 * 8),
+                    _mm256_loadu_ps(y + 3 * 8),
+                    _mm256_loadu_ps(y + 4 * 8),
+                    _mm256_loadu_ps(y + 5 * 8),
+                    _mm256_loadu_ps(y + 6 * 8),
+                    _mm256_loadu_ps(y + 7 * 8),
+                    v0,
+                    v1,
+                    v2,
+                    v3,
+                    v4,
+                    v5,
+                    v6,
+                    v7);
+
+            // compute differences
+            const __m256 d0 = _mm256_sub_ps(m0, v0);
+            const __m256 d1 = _mm256_sub_ps(m1, v1);
+            const __m256 d2 = _mm256_sub_ps(m2, v2);
+            const __m256 d3 = _mm256_sub_ps(m3, v3);
+            const __m256 d4 = _mm256_sub_ps(m4, v4);
+            const __m256 d5 = _mm256_sub_ps(m5, v5);
+            const __m256 d6 = _mm256_sub_ps(m6, v6);
+            const __m256 d7 = _mm256_sub_ps(m7, v7);
+
+            // compute squares of differences
+            __m256 distances = _mm256_mul_ps(d0, d0);
+            distances = _mm256_fmadd_ps(d1, d1, distances);
+            distances = _mm256_fmadd_ps(d2, d2, distances);
+            distances = _mm256_fmadd_ps(d3, d3, distances);
+            distances = _mm256_fmadd_ps(d4, d4, distances);
+            distances = _mm256_fmadd_ps(d5, d5, distances);
+            distances = _mm256_fmadd_ps(d6, d6, distances);
+            distances = _mm256_fmadd_ps(d7, d7, distances);
+
+            // store
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 64;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m256 x0 = _mm256_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
+            y += 8;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+#endif
+
 template <class ElementOp>
 void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
     __m128 x0 = _mm_loadu_ps(x);
@@ -509,9 +833,7 @@ void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
         y += 4;
         accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
         y += 4;
-        accu = _mm_hadd_ps(accu, accu);
-        accu = _mm_hadd_ps(accu, accu);
-        dis[i] = _mm_cvtss_f32(accu);
+        dis[i] = horizontal_sum(accu);
     }
 }
 
@@ -892,10 +1214,7 @@ size_t fvec_L2sqr_ny_nearest_D4(
         for (; i < ny; i++) {
             __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
             y += 4;
-            accu = _mm_hadd_ps(accu, accu);
-            accu = _mm_hadd_ps(accu, accu);
-
-            const auto distance = _mm_cvtss_f32(accu);
+            const float distance = horizontal_sum(accu);
 
             if (current_min_distance > distance) {
                 current_min_distance = distance;
@@ -1031,23 +1350,11 @@ size_t fvec_L2sqr_ny_nearest_D8(
         __m256 x0 = _mm256_loadu_ps(x);
 
         for (; i < ny; i++) {
-            __m256 sub = _mm256_sub_ps(x0, _mm256_loadu_ps(y));
-            __m256 accu = _mm256_mul_ps(sub, sub);
+            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
             y += 8;
-
-            // horitontal sum
-            const __m256 h0 = _mm256_hadd_ps(accu, accu);
-            const __m256 h1 = _mm256_hadd_ps(h0, h0);
-
-            // extract high and low __m128 regs from __m256
-            const __m128 h2 = _mm256_extractf128_ps(h1, 1);
-            const __m128 h3 = _mm256_castps256_ps128(h1);
-
-            // get a final hsum into all 4 regs
-            const __m128 h4 = _mm_add_ss(h2, h3);
-
-            // extract f[0] from __m128
-            const float distance = _mm_cvtss_f32(h4);
+            const float distance = horizontal_sum(accu);
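+            // note: ElementOpL2::op and horizontal_sum (both introduced in
+            // this diff) replace the hand-rolled hadd/extract reduction.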
 
             if (current_min_distance > distance) {
                 current_min_distance = distance;
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 55455bd79a..dfe8c13394 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -25,6 +25,7 @@ set(FAISS_TEST_SRC
   test_simdlib.cpp
   test_approx_topk.cpp
   test_RCQ_cropping.cpp
+  test_distances_simd.cpp
 )
 
 add_executable(faiss_test ${FAISS_TEST_SRC})
diff --git a/tests/test_distances_simd.cpp b/tests/test_distances_simd.cpp
new file mode 100644
index 0000000000..aab914e3cf
--- /dev/null
+++ b/tests/test_distances_simd.cpp
@@ -0,0 +1,145 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <random>
+#include <vector>
+
+#include <faiss/utils/distances.h>
+
+// reference implementations
+void fvec_inner_products_ny_ref(
+        float* ip,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    for (size_t i = 0; i < ny; i++) {
+        ip[i] = faiss::fvec_inner_product(x, y, d);
+        y += d;
+    }
+}
+
+void fvec_L2sqr_ny_ref(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    for (size_t i = 0; i < ny; i++) {
+        dis[i] = faiss::fvec_L2sqr(x, y, d);
+        y += d;
+    }
+}
+
+// test templated versions of fvec_L2sqr_ny
+TEST(TEST_FVEC_L2SQR_NY, D2) {
+    // we're using int values in order to get 100% accurate
+    // results with floats.
+    std::default_random_engine rng(123);
+    std::uniform_int_distribution<int32_t> u(0, 32);
+
+    for (const auto dim : {2, 4, 8, 12}) {
+        std::vector<float> x(dim, 0);
+        for (size_t i = 0; i < x.size(); i++) {
+            x[i] = u(rng);
+        }
+
+        for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) {
+            std::vector<float> y(nrows * dim);
+            for (size_t i = 0; i < y.size(); i++) {
+                y[i] = u(rng);
+            }
+
+            std::vector<float> distances(nrows, 0);
+            faiss::fvec_L2sqr_ny(
+                    distances.data(), x.data(), y.data(), dim, nrows);
+
+            std::vector<float> distances_ref(nrows, 0);
+            fvec_L2sqr_ny_ref(
+                    distances_ref.data(), x.data(), y.data(), dim, nrows);
+
+            ASSERT_EQ(distances, distances_ref)
+                    << "Mismatching results for dim = " << dim
+                    << ", nrows = " << nrows;
+        }
+    }
+}
+
+// fvec_inner_products_ny
+TEST(TEST_FVEC_INNER_PRODUCTS_NY, D2) {
+    // we're using int values in order to get 100% accurate
+    // results with floats.
+    std::default_random_engine rng(123);
+    std::uniform_int_distribution<int32_t> u(0, 32);
+
+    for (const auto dim : {2, 4, 8, 12}) {
+        std::vector<float> x(dim, 0);
+        for (size_t i = 0; i < x.size(); i++) {
+            x[i] = u(rng);
+        }
+
+        for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) {
+            std::vector<float> y(nrows * dim);
+            for (size_t i = 0; i < y.size(); i++) {
+                y[i] = u(rng);
+            }
+
+            std::vector<float> distances(nrows, 0);
+            faiss::fvec_inner_products_ny(
+                    distances.data(), x.data(), y.data(), dim, nrows);
+
+            std::vector<float> distances_ref(nrows, 0);
+            fvec_inner_products_ny_ref(
+                    distances_ref.data(), x.data(), y.data(), dim, nrows);
+
+            ASSERT_EQ(distances, distances_ref)
+                    << "Mismatching results for dim = " << dim
+                    << ", nrows = " << nrows;
+        }
+    }
+}
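+
+// an extra sanity check: fvec_L2sqr_ny_nearest is assumed to fill the
+// temporary distance buffer and to return the index of the nearest y row.
+TEST(TEST_FVEC_L2SQR_NY_NEAREST, D4) {
+    std::default_random_engine rng(456);
+    std::uniform_int_distribution<int32_t> u(0, 32);
+
+    const size_t dim = 4;
+    std::vector<float> x(dim, 0);
+    for (size_t i = 0; i < x.size(); i++) {
+        x[i] = u(rng);
+    }
+
+    for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) {
+        std::vector<float> y(nrows * dim);
+        for (size_t i = 0; i < y.size(); i++) {
+            y[i] = u(rng);
+        }
+
+        std::vector<float> distances(nrows, 0);
+        const size_t nearest = faiss::fvec_L2sqr_ny_nearest(
+                distances.data(), x.data(), y.data(), dim, nrows);
+        ASSERT_LT(nearest, (size_t)nrows);
+
+        std::vector<float> distances_ref(nrows, 0);
+        fvec_L2sqr_ny_ref(
+                distances_ref.data(), x.data(), y.data(), dim, nrows);
+
+        const size_t nearest_ref =
+                std::min_element(distances_ref.begin(), distances_ref.end()) -
+                distances_ref.begin();
+
+        ASSERT_EQ(distances_ref[nearest], distances_ref[nearest_ref])
+                << "Mismatching nearest result for nrows = " << nrows;
+    }
+}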