diff --git a/cmake/libs/libfaiss.cmake b/cmake/libs/libfaiss.cmake index 1e99fc6ac..4e20dde78 100644 --- a/cmake/libs/libfaiss.cmake +++ b/cmake/libs/libfaiss.cmake @@ -1,3 +1,5 @@ +include(CheckCXXCompilerFlag) + knowhere_file_glob( GLOB FAISS_SRCS thirdparty/faiss/faiss/*.cpp thirdparty/faiss/faiss/impl/*.cpp thirdparty/faiss/faiss/invlists/*.cpp @@ -47,12 +49,50 @@ if(__X86_64) endif() if(__AARCH64) - set(UTILS_SRC src/simd/hook.cc src/simd/distances_ref.cc - src/simd/distances_neon.cc) - add_library(knowhere_utils STATIC ${UTILS_SRC}) + + set(UTILS_SRC src/simd/distances_ref.cc src/simd/distances_neon.cc) + set(UTILS_SVE_SRC src/simd/hook.cc src/simd/distances_sve.cc) + set(ALL_UTILS_SRC ${UTILS_SRC} ${UTILS_SVE_SRC}) + + add_library( + knowhere_utils STATIC + ${ALL_UTILS_SRC} + ) + + check_cxx_compiler_flag("-march=armv9-a+sve" HAS_ARMV9_SVE) + if (HAS_ARMV9_SVE) + message(STATUS "SVE for ARMv9: Found") + else() + message(STATUS "SVE for ARMv9: Not Found") + endif() + + check_cxx_compiler_flag("-march=armv8-a+sve" HAS_ARMV8_SVE) + if (HAS_ARMV8_SVE) + message(STATUS "SVE for ARMv8: Found") + else() + message(STATUS "SVE for ARMv8: Not Found") + endif() + + if (HAS_ARMV9_SVE) + foreach(SVE_FILE ${UTILS_SVE_SRC}) + set_source_files_properties(${SVE_FILE} PROPERTIES COMPILE_OPTIONS "-march=armv9-a+sve") + target_compile_options(knowhere_utils PRIVATE -march=armv9-a) + endforeach() + elseif (HAS_ARMV8_SVE) + foreach(SVE_FILE ${UTILS_SVE_SRC}) + set_source_files_properties(${SVE_FILE} PROPERTIES COMPILE_OPTIONS "-march=armv8-a+sve") + target_compile_options(knowhere_utils PRIVATE -march=armv8-a) + endforeach() + else() + message(WARNING "SVE not supported on this platform.") + target_compile_options(knowhere_utils PRIVATE -march=armv8-a) + endif() + target_link_libraries(knowhere_utils PUBLIC glog::glog) endif() + + # ToDo: Add distances_vsx.cc for powerpc64 SIMD acceleration if(__PPC64) set(UTILS_SRC src/simd/hook.cc src/simd/distances_ref.cc src/simd/distances_powerpc.cc) diff --git a/src/simd/distances_sve.cc b/src/simd/distances_sve.cc new file mode 100644 index 000000000..5051ebec7 --- /dev/null +++ b/src/simd/distances_sve.cc @@ -0,0 +1,219 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "distances_sve.h" + +#include + +#include + +#include "faiss/impl/platform_macros.h" +#if defined(__ARM_FEATURE_SVE) +namespace faiss { + +float +fvec_L2sqr_sve(const float* x, const float* y, size_t d) { + svfloat32_t sum = svdup_f32(0.0f); + size_t i = 0; + + svbool_t pg = svptrue_b32(); + + while (i < d) { + if (d - i < svcntw()) + pg = svwhilelt_b32(i, d); + + svfloat32_t a = svld1_f32(pg, x + i); + svfloat32_t b = svld1_f32(pg, y + i); + svfloat32_t diff = svsub_f32_m(pg, a, b); + sum = svmla_f32_m(pg, sum, diff, diff); + i += svcntw(); + } + + return svaddv_f32(svptrue_b32(), sum); +} + +float +fvec_L1_sve(const float* x, const float* y, size_t d) { + svfloat32_t sum = svdup_f32(0.0f); + size_t i = 0; + + svbool_t pg = svptrue_b32(); + + while (i < d) { + if (d - i < svcntw()) + pg = svwhilelt_b32(i, d); + + svfloat32_t a = svld1_f32(pg, x + i); + svfloat32_t b = svld1_f32(pg, y + i); + svfloat32_t diff = svabs_f32_x(pg, svsub_f32_m(pg, a, b)); + sum = svadd_f32_m(pg, sum, diff); + i += svcntw(); + } + + return svaddv_f32(svptrue_b32(), sum); +} + +float +fvec_Linf_sve(const float* x, const float* y, size_t d) { + svfloat32_t max_val = svdup_f32(0.0f); + size_t i = 0; + + svbool_t pg = svptrue_b32(); + + while (i < d) { + if (d - i < svcntw()) + pg = svwhilelt_b32(i, d); + + svfloat32_t a = svld1_f32(pg, x + i); + svfloat32_t b = svld1_f32(pg, y + i); + svfloat32_t diff = svabs_f32_x(pg, svsub_f32_m(pg, a, b)); + max_val = svmax_f32_m(pg, max_val, diff); + i += svcntw(); + } + + return svmaxv_f32(svptrue_b32(), max_val); +} + +float +fvec_norm_L2sqr_sve(const float* x, size_t d) { + svfloat32_t sum = svdup_f32(0.0f); + size_t i = 0; + + svbool_t pg = svptrue_b32(); + + while (i < d) { + if (d - i < svcntw()) + pg = svwhilelt_b32(i, d); + + svfloat32_t a = svld1_f32(pg, x + i); + sum = svmla_f32_m(pg, sum, a, a); + i += svcntw(); + } + + return svaddv_f32(svptrue_b32(), sum); +} + +void +fvec_madd_sve(size_t n, const float* a, float bf, const float* b, float* c) { + size_t i = 0; + svfloat32_t bf_vec = svdup_f32(bf); + + svbool_t pg = svptrue_b32(); + + while (i < n) { + if (n - i < svcntw()) + pg = svwhilelt_b32(i, n); + + svfloat32_t a_vec = svld1_f32(pg, a + i); + svfloat32_t b_vec = svld1_f32(pg, b + i); + svfloat32_t c_vec = svmla_f32_m(pg, a_vec, b_vec, bf_vec); + svst1_f32(pg, c + i, c_vec); + i += svcntw(); + } +} + +int +fvec_madd_and_argmin_sve(size_t n, const float* a, float bf, const float* b, float* c) { + size_t i = 0; + svfloat32_t min_val = svdup_f32(INFINITY); + svuint32_t min_idx = svdup_u32(0); + svuint32_t idx_base = svindex_u32(0, 1); + + svfloat32_t bf_vec = svdup_f32(bf); + svbool_t pg = svptrue_b32(); + + while (i < n) { + if (n - i < svcntw()) + pg = svwhilelt_b32(i, n); + + svuint32_t idx = svadd_u32_z(pg, idx_base, svdup_u32(i)); + svfloat32_t a_vec = svld1_f32(pg, a + i); + svfloat32_t b_vec = svld1_f32(pg, b + i); + svfloat32_t c_vec = svmla_f32_m(pg, a_vec, b_vec, bf_vec); + svst1_f32(pg, c + i, c_vec); + + svbool_t cmp = svcmplt(pg, c_vec, min_val); + min_val = svsel_f32(cmp, c_vec, min_val); + min_idx = svsel_u32(cmp, idx, min_idx); + + i += svcntw(); + } + + float min_value = svminv_f32(svptrue_b32(), min_val); + svbool_t pg_min = svcmpeq(svptrue_b32(), min_val, svdup_f32(min_value)); + uint32_t min_index = svlastb_u32(pg_min, min_idx); + + return static_cast(min_index); +} + +void +fvec_inner_product_batch_4_sve(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3) { + svfloat32_t acc0 = svdup_f32(0.0f); + svfloat32_t acc1 = svdup_f32(0.0f); + svfloat32_t acc2 = svdup_f32(0.0f); + svfloat32_t acc3 = svdup_f32(0.0f); + + size_t i = 0; + svbool_t pg = svptrue_b32(); + + while (i < d) { + if (d - i < svcntw()) + pg = svwhilelt_b32(i, d); + + svfloat32_t vx = svld1(pg, &x[i]); + svfloat32_t vy0 = svld1(pg, &y0[i]); + svfloat32_t vy1 = svld1(pg, &y1[i]); + svfloat32_t vy2 = svld1(pg, &y2[i]); + svfloat32_t vy3 = svld1(pg, &y3[i]); + + acc0 = svmla_f32_m(pg, acc0, vx, vy0); + acc1 = svmla_f32_m(pg, acc1, vx, vy1); + acc2 = svmla_f32_m(pg, acc2, vx, vy2); + acc3 = svmla_f32_m(pg, acc3, vx, vy3); + + i += svcntw(); + } + + dis0 = svaddv_f32(svptrue_b32(), acc0); + dis1 = svaddv_f32(svptrue_b32(), acc1); + dis2 = svaddv_f32(svptrue_b32(), acc2); + dis3 = svaddv_f32(svptrue_b32(), acc3); +} + +void +fvec_L2sqr_batch_4_sve(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + + for (size_t i = 0; i < d; ++i) { + const float q0 = x[i] - y0[i]; + const float q1 = x[i] - y1[i]; + const float q2 = x[i] - y2[i]; + const float q3 = x[i] - y3[i]; + d0 += q0 * q0; + d1 += q1 * q1; + d2 += q2 * q2; + d3 += q3 * q3; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} + +} // namespace faiss + +#endif diff --git a/src/simd/distances_sve.h b/src/simd/distances_sve.h new file mode 100644 index 000000000..2d63e7ef3 --- /dev/null +++ b/src/simd/distances_sve.h @@ -0,0 +1,49 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include +#if defined(__ARM_FEATURE_SVE) +namespace faiss { + +float +fvec_L2sqr_sve(const float* x, const float* y, size_t d); + +float +fvec_L1_sve(const float* x, const float* y, size_t d); + +float +fvec_Linf_sve(const float* x, const float* y, size_t d); + +float +fvec_norm_L2sqr_sve(const float* x, size_t d); + +void +fvec_madd_sve(size_t n, const float* a, float bf, const float* b, float* c); + +int +fvec_madd_and_argmin_sve(size_t n, const float* a, float bf, const float* b, float* c); + +int32_t +ivec_L2sqr_sve(const int8_t* x, const int8_t* y, size_t d); + +void +fvec_inner_product_batch_4_sve(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); + +void +fvec_L2sqr_batch_4_sve(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); + +} // namespace faiss +#endif diff --git a/src/simd/hook.cc b/src/simd/hook.cc index 813c30b96..0c8acb952 100644 --- a/src/simd/hook.cc +++ b/src/simd/hook.cc @@ -26,10 +26,19 @@ #include "distances_neon.h" #endif +#if defined(__ARM_FEATURE_SVE) +#include "distances_sve.h" +#endif + #if defined(__powerpc64__) #include "distances_powerpc.h" #endif +#if defined(__aarch64__) +#include +#include +#endif + #include "distances_ref.h" namespace faiss { @@ -117,6 +126,14 @@ cpu_support_f16c() { } #endif +#if defined(__aarch64__) +bool +supports_sve() { + unsigned long hwcap = getauxval(AT_HWCAP); + return (hwcap & HWCAP_SVE) != 0; +} +#endif + static std::mutex patch_bf16_mutex; void @@ -146,12 +163,19 @@ enable_patch_for_fp32_bf16() { fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_bf16_patch_ref; } #endif -#if defined(__ARM_NEON) + +#if defined(__aarch64__) + +#if defined(__ARM_NEON) && !defined(__ARM_FEATURE_SVE) + fvec_inner_product = fvec_inner_product_bf16_patch_neon; fvec_inner_product_batch_4 = fvec_inner_product_batch_4_bf16_patch_neon; fvec_L2sqr = fvec_L2sqr_bf16_patch_neon; fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_bf16_patch_neon; + +#endif + #endif } @@ -378,43 +402,82 @@ fvec_hook(std::string& simd_type) { } #endif -#if defined(__ARM_NEON) - fvec_inner_product = fvec_inner_product_neon; - fvec_L2sqr = fvec_L2sqr_neon; - fvec_L1 = fvec_L1_neon; - fvec_Linf = fvec_Linf_neon; +#if defined(__aarch64__) + if (supports_sve()) { +#if defined(__ARM_FEATURE_SVE) + // ToDo: Enable remaining functions on SVE + fvec_L2sqr = fvec_L2sqr_sve; + fvec_L1 = fvec_L1_sve; + fvec_Linf = fvec_Linf_sve; + fvec_norm_L2sqr = fvec_norm_L2sqr_sve; + fvec_madd = fvec_madd_sve; + fvec_madd_and_argmin = fvec_madd_and_argmin_sve; - fvec_norm_L2sqr = fvec_norm_L2sqr_neon; - fvec_L2sqr_ny = fvec_L2sqr_ny_neon; - fvec_inner_products_ny = fvec_inner_products_ny_neon; - fvec_madd = fvec_madd_neon; - fvec_madd_and_argmin = fvec_madd_and_argmin_neon; + fvec_inner_product = fvec_inner_product_neon; + fvec_L2sqr_ny = fvec_L2sqr_ny_neon; + fvec_inner_products_ny = fvec_inner_products_ny_neon; - fvec_inner_product_batch_4 = fvec_inner_product_batch_4_neon; - fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_neon; + ivec_inner_product = ivec_inner_product_neon; + ivec_L2sqr = ivec_L2sqr_neon; - ivec_inner_product = ivec_inner_product_neon; - ivec_L2sqr = ivec_L2sqr_neon; + // fp16 + fp16_vec_inner_product = fp16_vec_inner_product_neon; + fp16_vec_L2sqr = fp16_vec_L2sqr_neon; + fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_neon; - // fp16 - fp16_vec_inner_product = fp16_vec_inner_product_neon; - fp16_vec_L2sqr = fp16_vec_L2sqr_neon; - fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_neon; + // bf16 + bf16_vec_inner_product = bf16_vec_inner_product_neon; + bf16_vec_L2sqr = bf16_vec_L2sqr_neon; + bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_neon; + fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_sve; + fvec_inner_product_batch_4 = fvec_inner_product_batch_4_sve; - fp16_vec_inner_product_batch_4 = fp16_vec_inner_product_batch_4_neon; - fp16_vec_L2sqr_batch_4 = fp16_vec_L2sqr_batch_4_neon; + simd_type = "SVE"; + support_pq_fast_scan = true; +#endif + } else { +#if defined(__ARM_NEON) + // NEON functions + fvec_inner_product = fvec_inner_product_neon; + fvec_L2sqr = fvec_L2sqr_neon; + fvec_L1 = fvec_L1_neon; + fvec_Linf = fvec_Linf_neon; + fvec_norm_L2sqr = fvec_norm_L2sqr_neon; + fvec_L2sqr_ny = fvec_L2sqr_ny_neon; + fvec_inner_products_ny = fvec_inner_products_ny_neon; + fvec_madd = fvec_madd_neon; + fvec_madd_and_argmin = fvec_madd_and_argmin_neon; + + ivec_inner_product = ivec_inner_product_neon; + ivec_L2sqr = ivec_L2sqr_neon; + + fvec_inner_product_batch_4 = fvec_inner_product_batch_4_neon; + fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_neon; + + ivec_inner_product = ivec_inner_product_neon; + ivec_L2sqr = ivec_L2sqr_neon; - // bf16 - bf16_vec_inner_product = bf16_vec_inner_product_neon; - bf16_vec_L2sqr = bf16_vec_L2sqr_neon; - bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_neon; + // fp16 + fp16_vec_inner_product = fp16_vec_inner_product_neon; + fp16_vec_L2sqr = fp16_vec_L2sqr_neon; + fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_neon; - bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_neon; - bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_neon; + fp16_vec_inner_product_batch_4 = fp16_vec_inner_product_batch_4_neon; + fp16_vec_L2sqr_batch_4 = fp16_vec_L2sqr_batch_4_neon; - // - simd_type = "NEON"; - support_pq_fast_scan = true; + // bf16 + bf16_vec_inner_product = bf16_vec_inner_product_neon; + bf16_vec_L2sqr = bf16_vec_L2sqr_neon; + bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_neon; + + bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_neon; + bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_neon; + + // + simd_type = "NEON"; + support_pq_fast_scan = true; +#endif + } #endif // ToDo MG: include VSX intrinsics via distances_vsx once _ref tests succeed diff --git a/src/simd/hook.h b/src/simd/hook.h index ed6f22546..b2d4e03e3 100644 --- a/src/simd/hook.h +++ b/src/simd/hook.h @@ -130,6 +130,11 @@ bool cpu_support_f16c(); #endif +#if defined(__aarch64__) +bool +supports_sve(); +#endif + void enable_patch_for_fp32_bf16(); diff --git a/tests/ut/test_knowhere_init.cc b/tests/ut/test_knowhere_init.cc index 2f1e7d3d7..0e8b6c496 100644 --- a/tests/ut/test_knowhere_init.cc +++ b/tests/ut/test_knowhere_init.cc @@ -58,7 +58,7 @@ TEST_CASE("Knowhere global config", "[init]") { } TEST_CASE("Knowhere SIMD config", "[simd]") { - std::vector v = {"AVX512", "AVX2", "SSE4_2", "GENERIC", "NEON"}; + std::vector v = {"AVX512", "AVX2", "SSE4_2", "GENERIC", "NEON", "SVE"}; std::unordered_set s(v.begin(), v.end()); auto res = knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX512);