Skip to content

Commit

Permalink
Enable SVE Support for L2 Metric Computation in FP32
Browse files Browse the repository at this point in the history
kind/feature
Signed-off-by:Adarsh Srivastava <adarsh.srivastava@fujitsu.com>

Signed-off-by: Adarsh Srivastava <Adarsh.Srivastava@fujitsu.com>
  • Loading branch information
adarshs1310 committed Nov 29, 2024
1 parent 1cb9937 commit fe28988
Show file tree
Hide file tree
Showing 6 changed files with 315 additions and 29 deletions.
28 changes: 28 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,31 @@ install(TARGETS knowhere
DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
install(DIRECTORY "${PROJECT_SOURCE_DIR}/include/knowhere"
DESTINATION "${CMAKE_INSTALL_PREFIX}/include")

if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")

# Find ARM SVE headers
find_path(ARM_SVE_DIR arm_sve.h PATHS
/usr/lib/gcc/aarch64-linux-gnu/*/include
/usr/lib/llvm-*/lib/clang/*/include
/usr/include
/usr/local/include
NO_DEFAULT_PATH
)
if(ARM_SVE_DIR)
include_directories(SYSTEM ${ARM_SVE_DIR})
endif()

# Find ARM NEON headers
find_path(ARM_NEON_DIR arm_neon.h PATHS
/usr/lib/gcc/aarch64-linux-gnu/*/include
/usr/lib/llvm-*/lib/clang/*/include
/usr/include
/usr/local/include
NO_DEFAULT_PATH
)
if(ARM_NEON_DIR)
include_directories(SYSTEM ${ARM_NEON_DIR})
endif()

endif()
2 changes: 1 addition & 1 deletion cmake/libs/libfaiss.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ endif()

if(__AARCH64)
set(UTILS_SRC src/simd/hook.cc src/simd/distances_ref.cc
src/simd/distances_neon.cc)
src/simd/distances_neon.cc src/simd/distances_sve.cc)
add_library(knowhere_utils STATIC ${UTILS_SRC})
target_link_libraries(knowhere_utils PUBLIC glog::glog)
endif()
Expand Down
178 changes: 178 additions & 0 deletions src/simd/distances_sve.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// Copyright (C) 2019-2023 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.


#include "distances_sve.h"
#include <arm_sve.h>
#include <cmath>
#include "faiss/impl/platform_macros.h"
#include "simd_util.h"
#pragma GCC optimize("O3,fast-math,inline")
#if defined(__ARM_FEATURE_SVE)
namespace faiss {

float fvec_L2sqr_sve(const float* x, const float* y, size_t d) {
svfloat32_t sum = svdup_f32(0.0f);
size_t i = 0;

svbool_t pg = svptrue_b32();

while (i < d) {
if (d - i < svcntw())
pg = svwhilelt_b32(i, d);

svfloat32_t a = svld1_f32(pg, x + i);
svfloat32_t b = svld1_f32(pg, y + i);
svfloat32_t diff = svsub_f32_m(pg, a, b);
sum = svmla_f32_m(pg, sum, diff, diff);
i += svcntw();
}

return svaddv_f32(svptrue_b32(), sum);
}

float fvec_L1_sve(const float* x, const float* y, size_t d) {
svfloat32_t sum = svdup_f32(0.0f);
size_t i = 0;

svbool_t pg = svptrue_b32();

while (i < d) {
if (d - i < svcntw())
pg = svwhilelt_b32(i, d);

svfloat32_t a = svld1_f32(pg, x + i);
svfloat32_t b = svld1_f32(pg, y + i);
svfloat32_t diff = svabs_f32_x(pg, svsub_f32_m(pg, a, b));
sum = svadd_f32_m(pg, sum, diff);
i += svcntw();
}

return svaddv_f32(svptrue_b32(), sum);
}

float fvec_Linf_sve(const float* x, const float* y, size_t d) {
svfloat32_t max_val = svdup_f32(0.0f);
size_t i = 0;

svbool_t pg = svptrue_b32();

while (i < d) {
if (d - i < svcntw())
pg = svwhilelt_b32(i, d);

svfloat32_t a = svld1_f32(pg, x + i);
svfloat32_t b = svld1_f32(pg, y + i);
svfloat32_t diff = svabs_f32_x(pg, svsub_f32_m(pg, a, b));
max_val = svmax_f32_m(pg, max_val, diff);
i += svcntw();
}

return svmaxv_f32(svptrue_b32(), max_val);
}

float fvec_norm_L2sqr_sve(const float* x, size_t d) {
svfloat32_t sum = svdup_f32(0.0f);
size_t i = 0;

svbool_t pg = svptrue_b32();

while (i < d) {
if (d - i < svcntw())
pg = svwhilelt_b32(i, d);

svfloat32_t a = svld1_f32(pg, x + i);
sum = svmla_f32_m(pg, sum, a, a);
i += svcntw();
}

return svaddv_f32(svptrue_b32(), sum);
}

void fvec_madd_sve(size_t n, const float* a, float bf, const float* b, float* c) {
size_t i = 0;
svfloat32_t bf_vec = svdup_f32(bf);

svbool_t pg = svptrue_b32();

while (i < n) {
if (n - i < svcntw())
pg = svwhilelt_b32(i, n);

svfloat32_t a_vec = svld1_f32(pg, a + i);
svfloat32_t b_vec = svld1_f32(pg, b + i);
svfloat32_t c_vec = svmla_f32_m(pg, a_vec, b_vec, bf_vec);
svst1_f32(pg, c + i, c_vec);
i += svcntw();
}
}

int fvec_madd_and_argmin_sve(size_t n, const float* a, float bf, const float* b, float* c) {
size_t i = 0;
svfloat32_t min_val = svdup_f32(INFINITY);
svuint32_t min_idx = svdup_u32(0);
svuint32_t idx_base = svindex_u32(0, 1);

svfloat32_t bf_vec = svdup_f32(bf);
svbool_t pg = svptrue_b32();

while (i < n) {
if (n - i < svcntw())
pg = svwhilelt_b32(i, n);

svuint32_t idx = svadd_u32_z(pg, idx_base, svdup_u32(i));
svfloat32_t a_vec = svld1_f32(pg, a + i);
svfloat32_t b_vec = svld1_f32(pg, b + i);
svfloat32_t c_vec = svmla_f32_m(pg, a_vec, b_vec, bf_vec);
svst1_f32(pg, c + i, c_vec);

svbool_t cmp = svcmplt(pg, c_vec, min_val);
min_val = svsel_f32(cmp, c_vec, min_val);
min_idx = svsel_u32(cmp, idx, min_idx);

i += svcntw();
}

float min_value = svminv_f32(svptrue_b32(), min_val);
svbool_t pg_min = svcmpeq(svptrue_b32(), min_val, svdup_f32(min_value));
uint32_t min_index = svlastb_u32(pg_min, min_idx);

return static_cast<int>(min_index);
}

void fvec_L2sqr_batch_4_sve(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3) {
float d0 = 0;
float d1 = 0;
float d2 = 0;
float d3 = 0;

for (size_t i = 0; i < d; ++i) {
const float q0 = x[i] - y0[i];
const float q1 = x[i] - y1[i];
const float q2 = x[i] - y2[i];
const float q3 = x[i] - y3[i];
d0 += q0 * q0;
d1 += q1 * q1;
d2 += q2 * q2;
d3 += q3 * q3;
}

dis0 = d0;
dis1 = d1;
dis2 = d2;
dis3 = d3;
}


} // namespace faiss

#endif
37 changes: 37 additions & 0 deletions src/simd/distances_sve.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (C) 2019-2023 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.


#include <cstdint>
#include <cstdio>
#include <arm_sve.h>
#if defined(__ARM_FEATURE_SVE)
namespace faiss {

float fvec_L2sqr_sve(const float* x, const float* y, size_t d);

float fvec_L1_sve(const float* x, const float* y, size_t d);

float fvec_Linf_sve(const float* x, const float* y, size_t d);

float fvec_norm_L2sqr_sve(const float* x, size_t d);

void fvec_madd_sve(size_t n, const float* a, float bf, const float* b, float* c);

int fvec_madd_and_argmin_sve(size_t n, const float* a, float bf, const float* b, float* c);

int32_t ivec_L2sqr_sve(const int8_t* x, const int8_t* y, size_t d);

void fvec_L2sqr_batch_4_sve(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

} // namespace faiss
#endif
97 changes: 70 additions & 27 deletions src/simd/hook.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@

#include "faiss/FaissHook.h"

#if defined(__ARM_FEATURE_SVE)
#include "distances_sve.h"
#endif

#if defined(__ARM_NEON)
#include "distances_neon.h"
#endif
Expand Down Expand Up @@ -132,15 +136,20 @@ enable_patch_for_fp32_bf16() {
fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_ref_bf16_patch;
}
#endif
#if defined(__ARM_NEON)

fvec_inner_product = fvec_inner_product_neon_bf16_patch;
fvec_L2sqr = fvec_L2sqr_neon_bf16_patch;
#if defined(__aarch64__)

#if defined(__ARM_NEON) && !defined(__ARM_FEATURE_SVE)

fvec_inner_product_batch_4 = fvec_inner_product_batch_4_neon_bf16_patch;
fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_neon_bf16_patch;
fvec_L2sqr = fvec_L2sqr_neon_bf16_patch;
fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_neon_bf16_patch;
fvec_inner_product = fvec_inner_product_neon_bf16_patch;
fvec_inner_product_batch_4 = fvec_inner_product_batch_4_neon_bf16_patch;

#endif

#endif

}

void
Expand Down Expand Up @@ -294,37 +303,71 @@ fvec_hook(std::string& simd_type) {
}
#endif

#if defined(__ARM_NEON)
fvec_inner_product = fvec_inner_product_neon;
fvec_L2sqr = fvec_L2sqr_neon;
fvec_L1 = fvec_L1_neon;
fvec_Linf = fvec_Linf_neon;
#if defined(__aarch64__)

fvec_norm_L2sqr = fvec_norm_L2sqr_neon;
fvec_L2sqr_ny = fvec_L2sqr_ny_neon;
fvec_inner_products_ny = fvec_inner_products_ny_neon;
fvec_madd = fvec_madd_neon;
fvec_madd_and_argmin = fvec_madd_and_argmin_neon;
#if defined(__ARM_FEATURE_SVE)
// ToDo: Enable remaining functions on SVE

fvec_L2sqr = fvec_L2sqr_sve;
fvec_L1 = fvec_L1_sve;
fvec_Linf = fvec_Linf_sve;
fvec_norm_L2sqr = fvec_norm_L2sqr_sve;
fvec_madd = fvec_madd_sve;
fvec_madd_and_argmin = fvec_madd_and_argmin_sve;

ivec_inner_product = ivec_inner_product_neon;
ivec_L2sqr = ivec_L2sqr_neon;
fvec_inner_product = fvec_inner_product_neon;
fvec_L2sqr_ny = fvec_L2sqr_ny_neon;
fvec_inner_products_ny = fvec_inner_products_ny_neon;

fp16_vec_inner_product = fp16_vec_inner_product_neon;
fp16_vec_L2sqr = fp16_vec_L2sqr_neon;
fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_neon;

bf16_vec_inner_product = bf16_vec_inner_product_neon;
bf16_vec_L2sqr = bf16_vec_L2sqr_neon;
bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_neon;
ivec_inner_product = ivec_inner_product_neon;
ivec_L2sqr = ivec_L2sqr_neon;

fvec_inner_product_batch_4 = fvec_inner_product_batch_4_neon;
fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_neon;
fp16_vec_inner_product = fp16_vec_inner_product_neon;
fp16_vec_L2sqr = fp16_vec_L2sqr_neon;
fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_neon;

simd_type = "NEON";
support_pq_fast_scan = true;
bf16_vec_inner_product = bf16_vec_inner_product_neon;
bf16_vec_L2sqr = bf16_vec_L2sqr_neon;
bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_neon;
fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_sve;
simd_type = "SVE";
support_pq_fast_scan = true;

#elif defined(__ARM_NEON)
// NEON functions
fvec_inner_product = fvec_inner_product_neon;
fvec_L2sqr = fvec_L2sqr_neon;
fvec_L1 = fvec_L1_neon;
fvec_Linf = fvec_Linf_neon;
fvec_norm_L2sqr = fvec_norm_L2sqr_neon;
fvec_L2sqr_ny = fvec_L2sqr_ny_neon;
fvec_inner_products_ny = fvec_inner_products_ny_neon;
fvec_madd = fvec_madd_neon;
fvec_madd_and_argmin = fvec_madd_and_argmin_neon;

ivec_inner_product = ivec_inner_product_neon;
ivec_L2sqr = ivec_L2sqr_neon;

fp16_vec_inner_product = fp16_vec_inner_product_neon;
fp16_vec_L2sqr = fp16_vec_L2sqr_neon;
fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_neon;

bf16_vec_inner_product = bf16_vec_inner_product_neon;
bf16_vec_L2sqr = bf16_vec_L2sqr_neon;
bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_neon;

fvec_inner_product_batch_4 = fvec_inner_product_batch_4_neon;
fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_neon;

simd_type = "NEON";
support_pq_fast_scan = true;

#endif

#endif


// ToDo MG: include VSX intrinsics via distances_vsx once _ref tests succeed
#if defined(__powerpc64__)
fvec_inner_product = fvec_inner_product_ppc;
Expand Down
2 changes: 1 addition & 1 deletion tests/ut/test_knowhere_init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ TEST_CASE("Knowhere global config", "[init]") {
}

TEST_CASE("Knowhere SIMD config", "[simd]") {
std::vector<std::string> v = {"AVX512", "AVX2", "SSE4_2", "GENERIC", "NEON"};
std::vector<std::string> v = {"AVX512", "AVX2", "SSE4_2", "GENERIC", "NEON", "SVE"};
std::unordered_set<std::string> s(v.begin(), v.end());

auto res = knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX512);
Expand Down

0 comments on commit fe28988

Please sign in to comment.