
Add Fused-Attention Layer for AVX2 Platforms (#137)
* add prefer_fp32

* add avx2 mha

* TARGET_512

* follow clang-tidy

* fix ut on linux

* move intrin utils to btla

* move btla templates to mha_dense_wrapper.h

* add f16c target for avx2 intrin

* move tests to mha_dense_tests.cpp

* fix format
DDEle authored Feb 26, 2024
1 parent 150e752 commit bc5ee16
Showing 12 changed files with 3,016 additions and 2,290 deletions.
91 changes: 90 additions & 1 deletion bestla/bestla/kernel_avx2.h
@@ -24,7 +24,7 @@ namespace avx2 {
#if CompileAVX2()
#ifdef __GNUC__
#pragma GCC push_options
-#pragma GCC target("avx2", "fma")
+#pragma GCC target("avx2", "fma", "f16c")
#else
#endif

@@ -1118,6 +1118,95 @@ static inline BTLA_CODE layernorm(const float* srcptr, const float* scaleptr, co
return BTLA_CODE::Success;
}

// Computes y * 2^z, where y = c0*f^2 + c1*f + c2, by adding z to the IEEE-754 exponent bits of y
// (AVX2 lacks a scalef instruction).
inline __m256 poly_scale_2nd_ps(const __m256i z, const __m256 f, const __m256 c0, const __m256 c1, const __m256 c2) {
  const auto y = _mm256_fmadd_ps(_mm256_fmadd_ps(f, c0, c1), f, c2);  // auto y = (f * c0 + c1) * f + c2;
  static const auto mask_exp = _mm256_set1_epi32(0x7f800000);
  static const auto mask_not_exp = _mm256_set1_epi32(~0x7f800000);

  const auto y_exp = _mm256_and_si256(_mm256_castps_si256(y), mask_exp);
  const auto y_not_exp = _mm256_and_si256(_mm256_castps_si256(y), mask_not_exp);

  const auto y_exp_scaled = _mm256_add_epi32(y_exp, _mm256_slli_epi32(z, 23));
  return _mm256_castsi256_ps(_mm256_or_si256(y_not_exp, _mm256_and_si256(y_exp_scaled, mask_exp)));
}

// exp(x) = 2^(x*log2e) = 2^z * 2^(f-0.5), with z = floor(x*log2e + 0.5) and f the fractional remainder;
// the 2nd-order polynomial approximates 2^(f-0.5) on [0, 1).
inline __m256 exp_ps_0_1(const __m256 x) {
  static const auto c0 = _mm256_set1_ps(0.240226507f);
  static const auto c1 = _mm256_set1_ps(0.452920674f);
  static const auto c2 = _mm256_set1_ps(0.713483036f);
  static const float v_log2e = std::log2(std::exp(1.f));
  static const auto log2e = _mm256_set1_ps(v_log2e);
  static const auto half = _mm256_set1_ps(.5f);

  const auto x1 = _mm256_fmadd_ps(x, log2e, half);  // auto x1 = x * log2e + _mm256_set1_ps(.5f);
  const auto z = _mm256_floor_ps(x1);
  const auto f = _mm256_sub_ps(x1, z);  // auto f = x1 - z;

  return poly_scale_2nd_ps(_mm256_cvtps_epi32(z), f, c0, c1, c2);
}

#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes" // https://stackoverflow.com/a/49216021
#endif
// Interleave 8 xmm vectors of words inplace
static inline std::array<__m128i, 8> tr_x8_word(std::array<__m128i, 8>& src) {  // NOLINT [runtime/references]
  std::array<__m128i, 8> dst;

  for (int i = 0; i < 8; i += 2) {
    dst[i + 0] = _mm_unpacklo_epi16(src[i + 0], src[i + 1]);
    dst[i + 1] = _mm_unpackhi_epi16(src[i + 0], src[i + 1]);
  }
  for (int i = 0; i < 8; i += 4) {
    src[i + 0] = _mm_unpacklo_epi32(dst[i + 0], dst[i + 2]);
    src[i + 1] = _mm_unpackhi_epi32(dst[i + 0], dst[i + 2]);
    src[i + 2] = _mm_unpacklo_epi32(dst[i + 1], dst[i + 3]);
    src[i + 3] = _mm_unpackhi_epi32(dst[i + 1], dst[i + 3]);
  }
  dst[0] = _mm_unpacklo_epi64(src[0], src[4]);
  dst[1] = _mm_unpackhi_epi64(src[0], src[4]);
  dst[2] = _mm_unpacklo_epi64(src[1], src[5]);
  dst[3] = _mm_unpackhi_epi64(src[1], src[5]);
  dst[4] = _mm_unpacklo_epi64(src[2], src[6]);
  dst[5] = _mm_unpackhi_epi64(src[2], src[6]);
  dst[6] = _mm_unpacklo_epi64(src[3], src[7]);
  dst[7] = _mm_unpackhi_epi64(src[3], src[7]);
  return dst;
}

template <int tail>
inline std::array<__m128i, 8> load_fp32_fp16_tr_x8_word(const float* a, size_t lda) {
  static_assert(tail > 0 && tail <= 8, "Unexpected tail value.");
  std::array<__m128i, 8> dst;
  for (int i = 0; i < tail; ++i) {
    dst[i] = _mm256_cvtps_ph(_mm256_loadu_ps(a + i * lda), _MM_FROUND_TO_NEAREST_INT);
  }
  for (int i = tail; i < 8; ++i) dst[i] = _mm_setzero_si128();
  return tr_x8_word(dst);
}
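// The dispatch table below maps a runtime row count (0..8) to the matching template instantiation;
// entries 0 and 1 both point to the <1> variant. The maskz table further down has the same layout.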
constexpr decltype(load_fp32_fp16_tr_x8_word<1>)* load_fp32_fp16_tr_x8_word_tbl[9]{
    load_fp32_fp16_tr_x8_word<1>, load_fp32_fp16_tr_x8_word<1>, load_fp32_fp16_tr_x8_word<2>,
    load_fp32_fp16_tr_x8_word<3>, load_fp32_fp16_tr_x8_word<4>, load_fp32_fp16_tr_x8_word<5>,
    load_fp32_fp16_tr_x8_word<6>, load_fp32_fp16_tr_x8_word<7>, load_fp32_fp16_tr_x8_word<8>};

template <int tail>
inline std::array<__m128i, 8> load_maskz_fp32_fp16_tr_x8_word(const float* a, size_t lda, __m256i mask) {
  static_assert(tail > 0 && tail <= 8, "Unexpected tail value.");
  std::array<__m128i, 8> dst;
  for (int i = 0; i < tail; ++i) {
    dst[i] = _mm256_cvtps_ph(_mm256_maskload_ps(a + i * lda, mask), _MM_FROUND_TO_NEAREST_INT);
  }
  for (int i = tail; i < 8; ++i) dst[i] = _mm_setzero_si128();
  return tr_x8_word(dst);
}
constexpr decltype(load_maskz_fp32_fp16_tr_x8_word<1>)* load_maskz_fp32_fp16_tr_x8_word_tbl[9]{
    load_maskz_fp32_fp16_tr_x8_word<1>, load_maskz_fp32_fp16_tr_x8_word<1>, load_maskz_fp32_fp16_tr_x8_word<2>,
    load_maskz_fp32_fp16_tr_x8_word<3>, load_maskz_fp32_fp16_tr_x8_word<4>, load_maskz_fp32_fp16_tr_x8_word<5>,
    load_maskz_fp32_fp16_tr_x8_word<6>, load_maskz_fp32_fp16_tr_x8_word<7>, load_maskz_fp32_fp16_tr_x8_word<8>};
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif

#ifdef __GNUC__
#pragma GCC pop_options
#else
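For reference, the exponent-bit trick used by poly_scale_2nd_ps above can be modeled with a scalar sketch (an illustration only, not part of this commit; poly_scale_2nd_scalar is a made-up name):

#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar model of the AVX2 poly_scale_2nd_ps: evaluates y = c0*f^2 + c1*f + c2 and then multiplies
// by 2^z by adding z to the IEEE-754 exponent field, since AVX2 has no scalef instruction.
inline float poly_scale_2nd_scalar(int z, float f, float c0, float c1, float c2) {
  const float y = (f * c0 + c1) * f + c2;
  uint32_t bits;
  std::memcpy(&bits, &y, sizeof(bits));
  const uint32_t exp_field = bits & 0x7f800000u;  // exponent bits of y
  const uint32_t rest = bits & ~0x7f800000u;      // sign and mantissa bits of y
  const uint32_t scaled = (exp_field + (static_cast<uint32_t>(z) << 23)) & 0x7f800000u;
  const uint32_t out_bits = rest | scaled;
  float out;
  std::memcpy(&out, &out_bits, sizeof(out));
  return out;  // equals y * std::exp2(z) for normal, non-overflowing results
}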
22 changes: 22 additions & 0 deletions bestla/bestla/kernel_avx512f.h
@@ -2383,6 +2383,28 @@ static inline BTLA_CODE layernorm(const float* srcptr, const float* scaleptr, co
}
return BTLA_CODE::Success;
}

// AVX-512 counterpart of the AVX2 helper: _mm512_scalef_ps(y, z) applies the 2^z scaling directly.
inline __m512 poly_scale_2nd_ps(const __m512 z, const __m512 f, const __m512 c0, const __m512 c1, const __m512 c2) {
  const auto y = _mm512_fmadd_ps(_mm512_fmadd_ps(f, c0, c1), f, c2);  // auto y = (f * c0 + c1) * f + c2;
  const auto exp = _mm512_scalef_ps(y, z);
  return exp;
}

inline __m512 exp_ps_0_1(const __m512 x) {
  static const auto c0 = _mm512_set1_ps(0.240226507f);
  static const auto c1 = _mm512_set1_ps(0.452920674f);
  static const auto c2 = _mm512_set1_ps(0.713483036f);
  static const float v_log2e = std::log2(std::exp(1.f));
  static const auto log2e = _mm512_set1_ps(v_log2e);
  static const auto half = _mm512_set1_ps(.5f);

  const auto x1 = _mm512_fmadd_ps(x, log2e, half);  // auto x1 = x * log2e + _mm512_set1_ps(.5f);
  const auto z = _mm512_floor_ps(x1);
  const auto f = _mm512_sub_ps(x1, z);  // auto f = x1 - z;

  return poly_scale_2nd_ps(z, f, c0, c1, c2);
}

#ifdef __GNUC__
#pragma GCC pop_options
#else
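The name exp_ps_0_1 suggests inputs are expected to be at or below zero, so results land in (0, 1] — as in a softmax after subtracting the row maximum. The scalar sketch below illustrates that usage; it is an assumption for illustration only (softmax_row is a made-up name, not code from this commit):

#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative softmax over one row of attention scores: subtracting the row maximum keeps every
// argument to exp() at or below 0, so each exponential stays in (0, 1].
std::vector<float> softmax_row(const std::vector<float>& logits) {  // assumes a non-empty row
  const float row_max = *std::max_element(logits.begin(), logits.end());
  std::vector<float> out(logits.size());
  float sum = 0.f;
  for (size_t i = 0; i < logits.size(); ++i) {
    out[i] = std::exp(logits[i] - row_max);  // the vector kernels would use exp_ps_0_1 here
    sum += out[i];
  }
  for (float& v : out) v /= sum;
  return out;
}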
11 changes: 11 additions & 0 deletions bestla/bestla/kernel_ref.h
@@ -1522,6 +1522,17 @@ static inline BTLA_CODE layernorm(const T* srcptr, const T* scaleptr, const T* b
}
return BTLA_CODE::Success;
}

inline float exp_ps_0_1(float x) {
  static const float log2e = std::log2(std::exp(1.f));
  static const float ln2 = std::log(2.f);
  const float x1 = x * log2e + .5f;
  const float z = std::floor(x1);
  const float f = x1 - z;
  constexpr std::array<float, 3> coeff{0.240226507f, 0.452920674f, 0.713483036f};
  // same as a * std::pow(2, z) but more precise
  return ldexpf(coeff[0] * f * f + coeff[1] * f + coeff[2], static_cast<int>(z));
}
} // namespace ref
} // namespace kernel
} // namespace bestla
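A minimal standalone check of the scalar reference above (an illustrative sketch, not part of this commit; the helper is copied from the hunk above, minus the unused ln2 constant, with a made-up name):

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdio>

// Local copy of kernel_ref.h's exp_ps_0_1 for a standalone accuracy check.
inline float exp_ps_0_1_ref(float x) {
  static const float log2e = std::log2(std::exp(1.f));
  const float x1 = x * log2e + .5f;
  const float z = std::floor(x1);
  const float f = x1 - z;
  constexpr std::array<float, 3> coeff{0.240226507f, 0.452920674f, 0.713483036f};
  return std::ldexp(coeff[0] * f * f + coeff[1] * f + coeff[2], static_cast<int>(z));
}

int main() {
  float max_rel_err = 0.f;
  for (float x = -20.f; x <= 0.f; x += 1e-3f) {
    const float expected = std::exp(x);
    const float got = exp_ps_0_1_ref(x);
    max_rel_err = std::max(max_rel_err, std::abs(got - expected) / expected);
  }
  std::printf("max relative error of exp_ps_0_1 on [-20, 0]: %g\n", max_rel_err);
  return 0;
}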
10 changes: 6 additions & 4 deletions neural_speed/core/CMakeLists.txt
@@ -14,6 +14,8 @@

find_package(Threads REQUIRED)
file(GLOB layers_srcs "layers/*.cpp")
file(GLOB test_srcs "layers/*test*.cpp")
list(REMOVE_ITEM layers_srcs ${test_srcs})
set(sources ne_layers.c ${layers_srcs})

add_shareable_library_w_warning(ne_layers "${sources}")
@@ -37,27 +39,27 @@ endif()

if (NS_BUILD_TESTS)

-function(add_test_target src)
+function(add_test_target src) # ARGN: additional source
get_filename_component(test_target ${src} NAME_WE)
get_filename_component(src_dir ${src} DIRECTORY)
string(REGEX REPLACE [/\\] "_" src_dir ${src_dir})
if(src_dir)
set (test_target "${src_dir}_${test_target}")
endif()
set (test_target "test_${test_target}")
-add_executable_w_warning(${test_target} ${src})
+add_executable_w_warning(${test_target} ${src} ${ARGN})
target_compile_definitions(${test_target} PRIVATE NS_TESTS)
target_compile_options(${test_target} PRIVATE -fsanitize=address)
target_link_options(${test_target} PRIVATE -fsanitize=address)
target_include_directories(${test_target} PUBLIC .)
-target_link_libraries(${test_target} PUBLIC Threads::Threads bestla::bestla ne_vec)
+target_link_libraries(${test_target} PUBLIC Threads::Threads bestla ne_vec)
if(NOT WIN32)
target_link_libraries(${test_target} PUBLIC rt)
endif()
add_test(NAME ${test_target} COMMAND ${test_target})
set_tests_properties(${test_target} PROPERTIES LABELS "${src_dir}_test")
endfunction()

-add_test_target(layers/mha_dense.cpp)
+add_test_target(layers/mha_dense.cpp layers/mha_dense_tests.cpp)

endif()
