Skip to content

Commit

Permalink
Make infinity work under linux/arm64 (#2082)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Issue link: #1720

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
yingfeng authored Oct 21, 2024
1 parent 5e89468 commit c693828
Show file tree
Hide file tree
Showing 11 changed files with 132 additions and 81 deletions.
14 changes: 13 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,22 @@ endif ()

MESSAGE(STATUS "C++ Compilation flags: " ${CMAKE_CXX_FLAGS})

# Classify the target CPU once so later logic can branch on ARM64 / X86_64.
# Neither variable is set for any other processor string.
if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(arm64|aarch64)")
    # Both spellings name the same 64-bit ARM architecture.
    set(ARM64 TRUE)
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$")
    set(X86_64 TRUE)
endif()

#add_definitions(-march=native)
add_definitions(-DSIMDE_ENABLE_NATIVE_ALIASES)
# Extra ISA flags for Clang 18 and newer.
# -mevex512 is an x86 target option, so it is only passed when X86_64 is set;
# every other architecture gets -march=native instead.
if (CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL "18.0")
    if (X86_64)
        add_definitions(-mevex512)
    else ()
        add_definitions(-march=native)
    endif ()
endif ()

execute_process(
Expand Down
136 changes: 74 additions & 62 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,52 @@ add_subdirectory(parser)
# add_definitions(-msse4.2 -mfma)
# add_definitions(-mavx2 -mf16c -mpopcnt)


# Probe the host CPU for the x86 SIMD extensions used by the build.
# Each SUPPORT_* variable receives the exit status of the probe command,
# so a value of 0 means the feature string was found.
# The probes only run when X86_64 is set; on other architectures the
# SUPPORT_* variables are left undefined.
if(APPLE)
    if(X86_64)
        execute_process(COMMAND sh -c "sysctl -a machdep.cpu.features | grep FMA"
                RESULT_VARIABLE SUPPORT_FMA
                OUTPUT_QUIET
                ERROR_QUIET)

        execute_process(COMMAND sh -c "sysctl -a machdep.cpu.features | grep SSE4.2"
                RESULT_VARIABLE SUPPORT_SSE42
                OUTPUT_QUIET
                ERROR_QUIET)

        execute_process(COMMAND sh -c "sysctl -a machdep.cpu.features | grep AVX2"
                RESULT_VARIABLE SUPPORT_AVX2
                OUTPUT_QUIET
                ERROR_QUIET)

        execute_process(COMMAND sh -c "sysctl -a machdep.cpu.features | grep AVX512"
                RESULT_VARIABLE SUPPORT_AVX512
                OUTPUT_QUIET
                ERROR_QUIET)
    endif()
else()
    # Linux: the feature flags are listed in /proc/cpuinfo.
    if(X86_64)
        execute_process(COMMAND grep -q fma /proc/cpuinfo
                RESULT_VARIABLE SUPPORT_FMA
                OUTPUT_QUIET
                ERROR_QUIET)

        execute_process(COMMAND grep -q sse4_2 /proc/cpuinfo
                RESULT_VARIABLE SUPPORT_SSE42
                OUTPUT_QUIET
                ERROR_QUIET)

        execute_process(COMMAND grep -q avx2 /proc/cpuinfo
                RESULT_VARIABLE SUPPORT_AVX2
                OUTPUT_QUIET
                ERROR_QUIET)

        execute_process(COMMAND grep -q avx512 /proc/cpuinfo
                RESULT_VARIABLE SUPPORT_AVX512
                OUTPUT_QUIET
                ERROR_QUIET)
    endif()
endif()


Expand Down Expand Up @@ -283,22 +288,24 @@ target_include_directories(infinity_core PUBLIC "${CMAKE_SOURCE_DIR}/third_party
target_include_directories(infinity_core PUBLIC "${CMAKE_BINARY_DIR}/third_party/pcre2")


# x86_64 only: enforce the minimum instruction set and choose the SIMD flags
# for infinity_core. The SUPPORT_* variables hold probe exit codes, so 0 means
# the feature is available. Other architectures skip this block entirely.
if (X86_64)
    if (NOT SUPPORT_FMA EQUAL 0)
        message(FATAL_ERROR "This project requires the processor support fused multiply-add (FMA) instructions.")
    endif ()

    if (NOT SUPPORT_SSE42 EQUAL 0)
        message(FATAL_ERROR "This project requires the processor support sse4_2 instructions.")
    endif ()

    if (SUPPORT_AVX2 EQUAL 0 OR SUPPORT_AVX512 EQUAL 0)
        message("Compiled by AVX2 or AVX512")
        add_definitions(-march=native)
        target_compile_options(infinity_core PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:-march=native>")
    else ()
        message("Compiled by SSE")
        add_definitions(-msse4.2 -mfma)
        # Quoted, ';'-separated list so the generator expression stays a single argument.
        target_compile_options(infinity_core PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:-msse4.2;-mfma>")
    endif ()
endif ()

add_executable(infinity
Expand Down Expand Up @@ -622,12 +629,17 @@ target_include_directories(unit_test PUBLIC "${CMAKE_BINARY_DIR}/third_party/pcr


# target_compile_options(unit_test PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-mavx2 -mfma -mf16c -mpopcnt>)
# Choose the SIMD flags for unit_test.
# On x86_64 the choice follows the probe exit codes (0 means available);
# every other architecture falls back to -march=native.
if (X86_64)
    if (SUPPORT_AVX2 EQUAL 0 OR SUPPORT_AVX512 EQUAL 0)
        message("Compiled by AVX2 or AVX512")
        add_definitions(-mavx2 -mfma -mf16c -mpopcnt)
        # Quoted, ';'-separated list so the generator expression stays a single argument.
        target_compile_options(unit_test PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:-mavx2;-mfma;-mf16c;-mpopcnt>")
    else ()
        message("Compiled by SSE")
        add_definitions(-msse4.2 -mfma)
        target_compile_options(unit_test PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:-msse4.2;-mfma>")
    endif ()
else ()
    add_definitions(-march=native)
endif ()

18 changes: 18 additions & 0 deletions src/common/simd/diskann_simd_func.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ export module diskann_simd_func;

namespace infinity {

#if defined(__aarch64__)
// Horizontal sum: adds the eight float lanes of `v` and returns the scalar total.
inline float hsum256_ps_avx(__m256 v) {
    const __m128 upper_half = _mm256_extractf128_ps(v, 1);
    const __m128 lower_half = _mm256_castps256_ps128(v);
    // 8 lanes -> 4 lanes.
    const __m128 quad = _mm_add_ps(upper_half, lower_half);
    // 4 lanes -> 2 lanes: fold the high pair onto the low pair.
    const __m128 pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
    // 2 lanes -> 1 lane: 0x55 selects lane 1 everywhere, add it onto lane 0.
    const __m128 total = _mm_add_ss(pair, _mm_shuffle_ps(pair, pair, 0x55));
    return _mm_cvtss_f32(total);
}

// Horizontal sum: adds the four float lanes of `v` and returns the scalar total.
inline float hsum_ps_sse1(__m128 v) { // v = [ D C | B A ]
    const __m128 swapped = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1)); // [ C D | A B ]
    const __m128 pair_sums = _mm_add_ps(v, swapped);                      // [ D+C C+D | B+A A+B ]
    const __m128 high_pair = _mm_movehl_ps(swapped, pair_sums);           // [ C D | D+C C+D ]
    // Lane 0 now becomes (A+B) + (C+D).
    return _mm_cvtss_f32(_mm_add_ss(pair_sums, high_pair));
}
#endif

export float hsumFloatVec(const float* array, size_t size) {
float sum = 0.0f;
size_t i = 0;
Expand Down
8 changes: 4 additions & 4 deletions src/common/simd/search_top_1_sgemm.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ void inner_search_top_1_with_sgemm_sse2(u32 dimension,
u32 x_id = i + x_part_begin;
float *ip_line = x_y_inner_product_buffer.get() + i * y_part_size;

_mm_prefetch(ip_line, _MM_HINT_NTA);
_mm_prefetch(ip_line + 8, _MM_HINT_NTA);
_mm_prefetch((const char *)(ip_line), _MM_HINT_NTA);
_mm_prefetch((const char *)(ip_line + 8), _MM_HINT_NTA);

const __m128 mul_minus2 = _mm_set1_ps(-2);

Expand All @@ -214,8 +214,8 @@ void inner_search_top_1_with_sgemm_sse2(u32 dimension,
u32 j = 0;
for (; j < (y_part_size / 8) * 8; j += 8, ip_line += 8) {
u32 j_id = j + y_part_begin;
_mm_prefetch(ip_line + 16, _MM_HINT_NTA);
_mm_prefetch(ip_line + 24, _MM_HINT_NTA);
_mm_prefetch((const char *)(ip_line + 16), _MM_HINT_NTA);
_mm_prefetch((const char *)(ip_line + 24), _MM_HINT_NTA);

__m128 y_norm_0 = _mm_loadu_ps(square_y.get() + j_id);
__m128 y_norm_1 = _mm_loadu_ps(square_y.get() + j_id + 4);
Expand Down
8 changes: 4 additions & 4 deletions src/common/simd/search_top_k_sgemm.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ void inner_search_top_k_with_sgemm_sse2(u32 k,
u32 x_id = i + x_part_begin;
float *ip_line = x_y_inner_product_buffer.get() + i * y_part_size;

_mm_prefetch(ip_line, _MM_HINT_NTA);
_mm_prefetch(ip_line + 8, _MM_HINT_NTA);
_mm_prefetch((const char*)ip_line, _MM_HINT_NTA);
_mm_prefetch((const char*)(ip_line + 8), _MM_HINT_NTA);

const __m128 x_norm = _mm_set1_ps(square_x[x_id]);
const __m128 mul_minus2 = _mm_set1_ps(-2);
Expand All @@ -189,8 +189,8 @@ void inner_search_top_k_with_sgemm_sse2(u32 k,
u32 j = 0;
for (; j < (y_part_size / 8) * 8; j += 8, ip_line += 8) {
u32 j_id = j + y_part_begin;
_mm_prefetch(ip_line + 16, _MM_HINT_NTA);
_mm_prefetch(ip_line + 24, _MM_HINT_NTA);
_mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA);
_mm_prefetch((const char*)(ip_line + 24), _MM_HINT_NTA);

const __m128 y_norm_0 = _mm_loadu_ps(square_y.get() + j_id + 0);
const __m128 y_norm_1 = _mm_loadu_ps(square_y.get() + j_id + 4);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ public:

// Issues _MM_HINT_T0 prefetch hints for the index and data buffers of the
// vector stored at position `idx`. `meta` is not used here.
void Prefetch(SizeT idx, const Meta &meta) const {
    const SparseVecEle &vec = vecs_[idx];
    // The Intel signature is _mm_prefetch(char const *, int); cast explicitly so
    // the call also compiles where the intrinsic header enforces that type.
    _mm_prefetch((const char*)vec.indices_.get(), _MM_HINT_T0);
    _mm_prefetch((const char*)vec.data_.get(), _MM_HINT_T0);
}

private:
Expand Down
2 changes: 1 addition & 1 deletion src/storage/knn_index/sparse/bmp_alg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module;

#include <algorithm>
#include <vector>
#include <xmmintrin.h>
#include "common/simd/simd_common_intrin_include.h"

module bmp_alg;

Expand Down
2 changes: 1 addition & 1 deletion src/storage/knn_index/sparse/bmp_alg_serialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
module;

#include <vector>
#include <xmmintrin.h>
#include "common/simd/simd_common_intrin_include.h"

module bmp_alg;

Expand Down
2 changes: 1 addition & 1 deletion src/storage/knn_index/sparse/bmp_blockterms.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

module;

#include <xmmintrin.h>
#include "common/simd/simd_common_intrin_include.h"

export module bmp_blockterms;

Expand Down
8 changes: 4 additions & 4 deletions src/storage/knn_index/sparse/bmp_posting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

module;

#include <xmmintrin.h>
#include "common/simd/simd_common_intrin_include.h"

module bm_posting;

Expand All @@ -41,8 +41,8 @@ void BlockData<DataType, BMPCompressType::kCompressed>::AddBlock(BMPBlockID bloc

// Issues _MM_HINT_T0 prefetch hints for the block-id and max-score arrays.
template <typename DataType>
void BlockData<DataType, BMPCompressType::kCompressed>::Prefetch() const {
    // The Intel signature is _mm_prefetch(char const *, int); cast explicitly so
    // the call also compiles where the intrinsic header enforces that type.
    _mm_prefetch((const char*)block_ids_.data(), _MM_HINT_T0);
    _mm_prefetch((const char*)max_scores_.data(), _MM_HINT_T0);
}

template struct BlockData<f32, BMPCompressType::kCompressed>;
Expand All @@ -67,7 +67,7 @@ void BlockData<DataType, BMPCompressType::kRaw>::AddBlock(BMPBlockID block_id, D

// Issues a _MM_HINT_T0 prefetch hint for the max-score array.
template <typename DataType>
void BlockData<DataType, BMPCompressType::kRaw>::Prefetch() const {
    // The Intel signature is _mm_prefetch(char const *, int); cast explicitly so
    // the call also compiles where the intrinsic header enforces that type.
    _mm_prefetch((const char*)max_scores_.data(), _MM_HINT_T0);
}

template struct BlockData<f32, BMPCompressType::kRaw>;
Expand Down
11 changes: 10 additions & 1 deletion src/unit_test/storage/knnindex/emvb_search/test_simd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

#include <cstdlib>
#include <immintrin.h>
#include "common/simd/simd_common_intrin_include.h"

#include "gtest/gtest.h"
import base_test;
Expand All @@ -25,6 +25,15 @@ using namespace infinity;

class SIMDTest : public BaseTest {};

#if defined(__aarch64__)
// Reduces the eight float lanes of `v` to a single scalar sum.
inline float hsum256_ps_avx(__m256 v) {
    // Add the upper 128-bit half onto the lower half.
    const __m128 sum4 = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v));
    // Add lanes 2..3 onto lanes 0..1.
    const __m128 hi2 = _mm_movehl_ps(sum4, sum4);
    const __m128 sum2 = _mm_add_ps(sum4, hi2);
    // Add lane 1 (selected by 0x55) onto lane 0.
    const __m128 lane1 = _mm_shuffle_ps(sum2, sum2, 0x55);
    const __m128 sum1 = _mm_add_ss(sum2, lane1);
    return _mm_cvtss_f32(sum1);
}
#endif

TEST_F(SIMDTest, testsum256) {
constexpr u32 test_sum256_loop = 20;

Expand Down

0 comments on commit c693828

Please sign in to comment.