This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit bc5ee16

Add Fused-Attention Layer for AVX2 Platforms (#137)
* add prefer_fp32
* add avx2 mha
* TARGET_512
* follow clang-tidy
* fix ut on linux
* move intrin utils to btla
* move btla templates to mha_dense_wrapper.h
* add f16c target for avx2 intrin
* move tests to mha_dense_tests.cpp
* fix format
1 parent 150e752 commit bc5ee16

File tree

12 files changed: +3016, −2290 lines


bestla/bestla/kernel_avx2.h

Lines changed: 90 additions & 1 deletion
@@ -24,7 +24,7 @@ namespace avx2 {
 #if CompileAVX2()
 #ifdef __GNUC__
 #pragma GCC push_options
-#pragma GCC target("avx2", "fma")
+#pragma GCC target("avx2", "fma", "f16c")
 #else
 #endif
 
@@ -1118,6 +1118,95 @@ static inline BTLA_CODE layernorm(const float* srcptr, const float* scaleptr, co
   return BTLA_CODE::Success;
 }
 
+inline __m256 poly_scale_2nd_ps(const __m256i z, const __m256 f, const __m256 c0, const __m256 c1, const __m256 c2) {
+  const auto y = _mm256_fmadd_ps(_mm256_fmadd_ps(f, c0, c1), f, c2);  // auto y = (f * c0 + c1) * f + c2;
+  static const auto mask_exp = _mm256_set1_epi32(0x7f800000);
+  static const auto mask_not_exp = _mm256_set1_epi32(~0x7f800000);
+
+  const auto y_exp = _mm256_and_si256(_mm256_castps_si256(y), mask_exp);
+  const auto y_not_exp = _mm256_and_si256(_mm256_castps_si256(y), mask_not_exp);
+
+  const auto y_exp_scaled = _mm256_add_epi32(y_exp, _mm256_slli_epi32(z, 23));
+  return _mm256_castsi256_ps(_mm256_or_si256(y_not_exp, _mm256_and_si256(y_exp_scaled, mask_exp)));
+}
+
+inline __m256 exp_ps_0_1(const __m256 x) {
+  static const auto c0 = _mm256_set1_ps(0.240226507f);
+  static const auto c1 = _mm256_set1_ps(0.452920674f);
+  static const auto c2 = _mm256_set1_ps(0.713483036f);
+  static const float v_log2e = std::log2(std::exp(1.f));
+  static const auto log2e = _mm256_set1_ps(v_log2e);
+  static const auto half = _mm256_set1_ps(.5f);
+
+  const auto x1 = _mm256_fmadd_ps(x, log2e, half);  // auto x1 = x * log2e + _mm256_set1_ps(.5f);
+  const auto z = _mm256_floor_ps(x1);
+  const auto f = _mm256_sub_ps(x1, z);  // auto f = x1 - z;
+
+  return poly_scale_2nd_ps(_mm256_cvtps_epi32(z), f, c0, c1, c2);
+}
+
+#ifdef __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wignored-attributes"  // https://stackoverflow.com/a/49216021
+#endif
+// Interleave 8 xmm vectors of words inplace
+static inline std::array<__m128i, 8> tr_x8_word(std::array<__m128i, 8>& src) {  // NOLINT [runtime/references]
+  std::array<__m128i, 8> dst;
+
+  for (int i = 0; i < 8; i += 2) {
+    dst[i + 0] = _mm_unpacklo_epi16(src[i + 0], src[i + 1]);
+    dst[i + 1] = _mm_unpackhi_epi16(src[i + 0], src[i + 1]);
+  }
+  for (int i = 0; i < 8; i += 4) {
+    src[i + 0] = _mm_unpacklo_epi32(dst[i + 0], dst[i + 2]);
+    src[i + 1] = _mm_unpackhi_epi32(dst[i + 0], dst[i + 2]);
+    src[i + 2] = _mm_unpacklo_epi32(dst[i + 1], dst[i + 3]);
+    src[i + 3] = _mm_unpackhi_epi32(dst[i + 1], dst[i + 3]);
+  }
+  dst[0] = _mm_unpacklo_epi64(src[0], src[4]);
+  dst[1] = _mm_unpackhi_epi64(src[0], src[4]);
+  dst[2] = _mm_unpacklo_epi64(src[1], src[5]);
+  dst[3] = _mm_unpackhi_epi64(src[1], src[5]);
+  dst[4] = _mm_unpacklo_epi64(src[2], src[6]);
+  dst[5] = _mm_unpackhi_epi64(src[2], src[6]);
+  dst[6] = _mm_unpacklo_epi64(src[3], src[7]);
+  dst[7] = _mm_unpackhi_epi64(src[3], src[7]);
+  return dst;
+}
+
+template <int tail>
+inline std::array<__m128i, 8> load_fp32_fp16_tr_x8_word(const float* a, size_t lda) {
+  static_assert(tail > 0 && tail <= 8, "Unexpected tail value.");
+  std::array<__m128i, 8> dst;
+  for (int i = 0; i < tail; ++i) {
+    dst[i] = _mm256_cvtps_ph(_mm256_loadu_ps(a + i * lda), _MM_FROUND_TO_NEAREST_INT);
+  }
+  for (int i = tail; i < 8; ++i) dst[i] = _mm_setzero_si128();
+  return tr_x8_word(dst);
+}
+constexpr decltype(load_fp32_fp16_tr_x8_word<1>)* load_fp32_fp16_tr_x8_word_tbl[9]{
+    load_fp32_fp16_tr_x8_word<1>, load_fp32_fp16_tr_x8_word<1>, load_fp32_fp16_tr_x8_word<2>,
+    load_fp32_fp16_tr_x8_word<3>, load_fp32_fp16_tr_x8_word<4>, load_fp32_fp16_tr_x8_word<5>,
+    load_fp32_fp16_tr_x8_word<6>, load_fp32_fp16_tr_x8_word<7>, load_fp32_fp16_tr_x8_word<8>};
+
+template <int tail>
+inline std::array<__m128i, 8> load_maskz_fp32_fp16_tr_x8_word(const float* a, size_t lda, __m256i mask) {
+  static_assert(tail > 0 && tail <= 8, "Unexpected tail value.");
+  std::array<__m128i, 8> dst;
+  for (int i = 0; i < tail; ++i) {
+    dst[i] = _mm256_cvtps_ph(_mm256_maskload_ps(a + i * lda, mask), _MM_FROUND_TO_NEAREST_INT);
+  }
+  for (int i = tail; i < 8; ++i) dst[i] = _mm_setzero_si128();
+  return tr_x8_word(dst);
+}
+constexpr decltype(load_maskz_fp32_fp16_tr_x8_word<1>)* load_maskz_fp32_fp16_tr_x8_word_tbl[9]{
+    load_maskz_fp32_fp16_tr_x8_word<1>, load_maskz_fp32_fp16_tr_x8_word<1>, load_maskz_fp32_fp16_tr_x8_word<2>,
+    load_maskz_fp32_fp16_tr_x8_word<3>, load_maskz_fp32_fp16_tr_x8_word<4>, load_maskz_fp32_fp16_tr_x8_word<5>,
+    load_maskz_fp32_fp16_tr_x8_word<6>, load_maskz_fp32_fp16_tr_x8_word<7>, load_maskz_fp32_fp16_tr_x8_word<8>};
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
+#endif
+
 #ifdef __GNUC__
 #pragma GCC pop_options
 #else
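Side note (not part of the commit): AVX2 has no _mm512_scalef_ps, so poly_scale_2nd_ps scales the polynomial result by 2^z by adding z directly to the float's biased exponent field (y_exp + (z << 23)) and reassembling the value. A minimal scalar sketch of that trick, assuming normal, non-overflowing inputs as the vector code does:

// Scalar sketch (illustrative, not from the commit) of the exponent-bit scaling in poly_scale_2nd_ps:
// for a normal float y and a small integer z, adding z to the biased exponent field multiplies y by 2^z.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

static float scale_by_pow2(float y, int32_t z) {
  uint32_t bits;
  std::memcpy(&bits, &y, sizeof(bits));
  const uint32_t exp_bits = bits & 0x7f800000u;  // biased exponent field
  const uint32_t rest = bits & ~0x7f800000u;     // sign and mantissa
  const uint32_t scaled = (exp_bits + (static_cast<uint32_t>(z) << 23)) & 0x7f800000u;
  bits = rest | scaled;                          // reassemble the float
  std::memcpy(&y, &bits, sizeof(bits));
  return y;
}

int main() {
  // Should agree with std::ldexp for in-range values.
  std::printf("%f vs %f\n", scale_by_pow2(0.713483036f, 3), std::ldexp(0.713483036f, 3));
  return 0;
}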

bestla/bestla/kernel_avx512f.h

Lines changed: 22 additions & 0 deletions
@@ -2383,6 +2383,28 @@ static inline BTLA_CODE layernorm(const float* srcptr, const float* scaleptr, co
   }
   return BTLA_CODE::Success;
 }
+
+inline __m512 poly_scale_2nd_ps(const __m512 z, const __m512 f, const __m512 c0, const __m512 c1, const __m512 c2) {
+  const auto y = _mm512_fmadd_ps(_mm512_fmadd_ps(f, c0, c1), f, c2);  // auto y = (f * c0 + c1) * f + c2;
+  const auto exp = _mm512_scalef_ps(y, z);
+  return exp;
+}
+
+inline __m512 exp_ps_0_1(const __m512 x) {
+  static const auto c0 = _mm512_set1_ps(0.240226507f);
+  static const auto c1 = _mm512_set1_ps(0.452920674f);
+  static const auto c2 = _mm512_set1_ps(0.713483036f);
+  static const float v_log2e = std::log2(std::exp(1.f));
+  static const auto log2e = _mm512_set1_ps(v_log2e);
+  static const auto half = _mm512_set1_ps(.5f);
+
+  const auto x1 = _mm512_fmadd_ps(x, log2e, half);  // auto x1 = x * log2e + _mm512_set1_ps(.5f);
+  const auto z = _mm512_floor_ps(x1);
+  const auto f = _mm512_sub_ps(x1, z);  // auto f = x1 - z;
+
+  return poly_scale_2nd_ps(z, f, c0, c1, c2);
+}
+
 #ifdef __GNUC__
 #pragma GCC pop_options
 #else
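For comparison (not part of the commit): _mm512_scalef_ps(y, z) computes y * 2^floor(z) per lane, so the AVX-512 path needs no manual exponent manipulation. A one-lane scalar model, as an assumption-level sketch:

// Scalar model (illustrative, not from the commit) of one lane of _mm512_scalef_ps(y, z).
#include <cmath>

inline float scalef_scalar(float y, float z) {
  return std::ldexp(y, static_cast<int>(std::floor(z)));  // y * 2^floor(z)
}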

bestla/bestla/kernel_ref.h

Lines changed: 11 additions & 0 deletions
@@ -1522,6 +1522,17 @@ static inline BTLA_CODE layernorm(const T* srcptr, const T* scaleptr, const T* b
   }
   return BTLA_CODE::Success;
 }
+
+inline float exp_ps_0_1(float x) {
+  static const float log2e = std::log2(std::exp(1.f));
+  static const float ln2 = std::log(2.f);
+  const float x1 = x * log2e + .5f;
+  const float z = std::floor(x1);
+  const float f = x1 - z;
+  constexpr std::array<float, 3> coeff{0.240226507f, 0.452920674f, 0.713483036f};
+  // same as a * std::pow(2, z) but more precise
+  return ldexpf(coeff[0] * f * f + coeff[1] * f + coeff[2], static_cast<int>(z));
+}
 }  // namespace ref
 }  // namespace kernel
 }  // namespace bestla
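The reference version makes the approximation easy to sanity-check in plain C++. Below is a small standalone check, not part of the commit, that re-derives the same formula with the same coefficients (rather than including kernel_ref.h) and reports the worst relative error against std::exp over [0, 1]:

// Standalone check (illustrative, not from the commit) of the 2nd-order exp approximation:
// with x1 = x * log2(e) + 0.5, z = floor(x1) and f = x1 - z, the polynomial
// c0*f^2 + c1*f + c2 approximates 2^(f - 0.5), and ldexp scales the result by 2^z.
#include <algorithm>
#include <cmath>
#include <cstdio>

static float exp_approx(float x) {
  const float log2e = std::log2(std::exp(1.f));
  const float c0 = 0.240226507f, c1 = 0.452920674f, c2 = 0.713483036f;
  const float x1 = x * log2e + .5f;
  const float z = std::floor(x1);
  const float f = x1 - z;
  return std::ldexp(c0 * f * f + c1 * f + c2, static_cast<int>(z));
}

int main() {
  float max_rel_err = 0.f;
  for (int i = 0; i <= 256; ++i) {
    const float x = static_cast<float>(i) / 256.f;
    const float rel = std::abs(exp_approx(x) - std::exp(x)) / std::exp(x);
    max_rel_err = std::max(max_rel_err, rel);
  }
  std::printf("max relative error on [0, 1]: %g\n", max_rel_err);
  return 0;
}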

neural_speed/core/CMakeLists.txt

Lines changed: 6 additions & 4 deletions
@@ -14,6 +14,8 @@
 
 find_package(Threads REQUIRED)
 file(GLOB layers_srcs "layers/*.cpp")
+file(GLOB test_srcs "layers/*test*.cpp")
+list(REMOVE_ITEM layers_srcs ${test_srcs})
 set(sources ne_layers.c ${layers_srcs})
 
 add_shareable_library_w_warning(ne_layers "${sources}")
@@ -37,27 +39,27 @@ endif()
 
 if (NS_BUILD_TESTS)
 
-function(add_test_target src)
+function(add_test_target src) # ARGN: additional source
   get_filename_component(test_target ${src} NAME_WE)
   get_filename_component(src_dir ${src} DIRECTORY)
   string(REGEX REPLACE [/\\] "_" src_dir ${src_dir})
   if(src_dir)
     set (test_target "${src_dir}_${test_target}")
   endif()
   set (test_target "test_${test_target}")
-  add_executable_w_warning(${test_target} ${src})
+  add_executable_w_warning(${test_target} ${src} ${ARGN})
   target_compile_definitions(${test_target} PRIVATE NS_TESTS)
   target_compile_options(${test_target} PRIVATE -fsanitize=address)
   target_link_options(${test_target} PRIVATE -fsanitize=address)
   target_include_directories(${test_target} PUBLIC .)
-  target_link_libraries(${test_target} PUBLIC Threads::Threads bestla::bestla ne_vec)
+  target_link_libraries(${test_target} PUBLIC Threads::Threads bestla ne_vec)
   if(NOT WIN32)
     target_link_libraries(${test_target} PUBLIC rt)
   endif()
   add_test(NAME ${test_target} COMMAND ${test_target})
   set_tests_properties(${test_target} PROPERTIES LABELS "${src_dir}_test")
 endfunction()
 
-add_test_target(layers/mha_dense.cpp)
+add_test_target(layers/mha_dense.cpp layers/mha_dense_tests.cpp)
 
 endif()
