@@ -24,7 +24,7 @@ namespace avx2 {
24
24
#if CompileAVX2()
25
25
#ifdef __GNUC__
26
26
#pragma GCC push_options
27
- #pragma GCC target("avx2", "fma")
27
+ #pragma GCC target("avx2", "fma", "f16c" )
28
28
#else
29
29
#endif
30
30
@@ -1118,6 +1118,95 @@ static inline BTLA_CODE layernorm(const float* srcptr, const float* scaleptr, co
1118
1118
return BTLA_CODE::Success;
1119
1119
}
1120
1120
1121
// Evaluates the quadratic (f * c0 + c1) * f + c2 in every lane and multiplies it by 2^z.
//
// The scaling is done on the IEEE-754 bit pattern rather than with a multiply: z is shifted
// into the exponent field (bits 23..30) and added to the polynomial's exponent, while the sign
// and mantissa bits are carried over unchanged.
//
//   z           per-lane integer added to the binary exponent of the polynomial value
//   f           per-lane polynomial argument
//   c0, c1, c2  quadratic, linear and constant coefficients
//
// Returns the polynomial value with z added to its exponent field.
inline __m256 poly_scale_2nd_ps(const __m256i z, const __m256 f, const __m256 c0, const __m256 c1, const __m256 c2) {
  const auto y = _mm256_fmadd_ps(_mm256_fmadd_ps(f, c0, c1), f, c2);  // y = (f * c0 + c1) * f + c2

  // Plain locals instead of function-local statics: these are pure broadcast constants, so
  // there is nothing worth caching behind a thread-safe static-initialisation guard.
  const auto mask_exp = _mm256_set1_epi32(0x7f800000);       // exponent bits only
  const auto mask_not_exp = _mm256_set1_epi32(~0x7f800000);  // sign + mantissa bits

  const auto y_bits = _mm256_castps_si256(y);
  const auto y_exp = _mm256_and_si256(y_bits, mask_exp);
  const auto y_not_exp = _mm256_and_si256(y_bits, mask_not_exp);

  // The sum is masked again so a carry out of the exponent field cannot leak into the sign bit.
  const auto y_exp_scaled = _mm256_add_epi32(y_exp, _mm256_slli_epi32(z, 23));
  return _mm256_castsi256_ps(_mm256_or_si256(y_not_exp, _mm256_and_si256(y_exp_scaled, mask_exp)));
}
1132
+
1133
+ inline __m256 exp_ps_0_1 (const __m256 x) {
1134
+ static const auto c0 = _mm256_set1_ps (0 .240226507f );
1135
+ static const auto c1 = _mm256_set1_ps (0 .452920674f );
1136
+ static const auto c2 = _mm256_set1_ps (0 .713483036f );
1137
+ static const float v_log2e = std::log2 (std::exp (1 .f ));
1138
+ static const auto log2e = _mm256_set1_ps (v_log2e);
1139
+ static const auto half = _mm256_set1_ps (.5f );
1140
+
1141
+ const auto x1 = _mm256_fmadd_ps (x, log2e, half); // auto x1 = x * log2e + _mm256_set1_ps(.5f);
1142
+ const auto z = _mm256_floor_ps (x1);
1143
+ const auto f = _mm256_sub_ps (x1, z); // auto f = x1 - z;
1144
+
1145
+ return poly_scale_2nd_ps (_mm256_cvtps_epi32 (z), f, c0, c1, c2);
1146
+ }
1147
+
1148
+ #ifdef __GNUC__
1149
+ #pragma GCC diagnostic push
1150
+ #pragma GCC diagnostic ignored "-Wignored-attributes" // https://stackoverflow.com/a/49216021
1151
+ #endif
1152
// Transposes an 8x8 matrix of 16-bit words held in eight xmm registers (one row per register).
//
// The transposed rows are returned by value. `src` is NOT preserved: it is reused as scratch
// space and holds the intermediate 32-bit interleave results when the function returns.
static inline std::array<__m128i, 8> tr_x8_word(std::array<__m128i, 8>& src) {  // NOLINT [runtime/references]
  std::array<__m128i, 8> dst;

  // Stage 1: interleave the 16-bit words of each adjacent pair of rows.
  for (int i = 0; i < 8; i += 2) {
    dst[i + 0] = _mm_unpacklo_epi16(src[i + 0], src[i + 1]);
    dst[i + 1] = _mm_unpackhi_epi16(src[i + 0], src[i + 1]);
  }
  // Stage 2: interleave 32-bit groups inside each block of four rows (results go back to src).
  for (int i = 0; i < 8; i += 4) {
    src[i + 0] = _mm_unpacklo_epi32(dst[i + 0], dst[i + 2]);
    src[i + 1] = _mm_unpackhi_epi32(dst[i + 0], dst[i + 2]);
    src[i + 2] = _mm_unpacklo_epi32(dst[i + 1], dst[i + 3]);
    src[i + 3] = _mm_unpackhi_epi32(dst[i + 1], dst[i + 3]);
  }
  // Stage 3: join the 64-bit halves of the upper and lower four-row blocks; each pair of
  // outputs is one pair of columns of the original matrix.
  for (int i = 0; i < 4; ++i) {
    dst[2 * i + 0] = _mm_unpacklo_epi64(src[i], src[i + 4]);
    dst[2 * i + 1] = _mm_unpackhi_epi64(src[i], src[i + 4]);
  }
  return dst;
}
1176
+
1177
+ template <int tail>
1178
+ inline std::array<__m128i, 8 > load_fp32_fp16_tr_x8_word (const float * a, size_t lda) {
1179
+ static_assert (tail > 0 && tail <= 8 , " Unexpected tail value." );
1180
+ std::array<__m128i, 8 > dst;
1181
+ for (int i = 0 ; i < tail; ++i) {
1182
+ dst[i] = _mm256_cvtps_ph (_mm256_loadu_ps (a + i * lda), _MM_FROUND_TO_NEAREST_INT);
1183
+ }
1184
+ for (int i = tail; i < 8 ; ++i) dst[i] = _mm_setzero_si128 ();
1185
+ return tr_x8_word (dst);
1186
+ }
1187
+ constexpr decltype (load_fp32_fp16_tr_x8_word<1 >)* load_fp32_fp16_tr_x8_word_tbl[9]{
1188
+ load_fp32_fp16_tr_x8_word<1 >, load_fp32_fp16_tr_x8_word<1 >, load_fp32_fp16_tr_x8_word<2 >,
1189
+ load_fp32_fp16_tr_x8_word<3 >, load_fp32_fp16_tr_x8_word<4 >, load_fp32_fp16_tr_x8_word<5 >,
1190
+ load_fp32_fp16_tr_x8_word<6 >, load_fp32_fp16_tr_x8_word<7 >, load_fp32_fp16_tr_x8_word<8 >};
1191
+
1192
+ template <int tail>
1193
+ inline std::array<__m128i, 8 > load_maskz_fp32_fp16_tr_x8_word (const float * a, size_t lda, __m256i mask) {
1194
+ static_assert (tail > 0 && tail <= 8 , " Unexpected tail value." );
1195
+ std::array<__m128i, 8 > dst;
1196
+ for (int i = 0 ; i < tail; ++i) {
1197
+ dst[i] = _mm256_cvtps_ph (_mm256_maskload_ps (a + i * lda, mask), _MM_FROUND_TO_NEAREST_INT);
1198
+ }
1199
+ for (int i = tail; i < 8 ; ++i) dst[i] = _mm_setzero_si128 ();
1200
+ return tr_x8_word (dst);
1201
+ }
1202
+ constexpr decltype (load_maskz_fp32_fp16_tr_x8_word<1 >)* load_maskz_fp32_fp16_tr_x8_word_tbl[9]{
1203
+ load_maskz_fp32_fp16_tr_x8_word<1 >, load_maskz_fp32_fp16_tr_x8_word<1 >, load_maskz_fp32_fp16_tr_x8_word<2 >,
1204
+ load_maskz_fp32_fp16_tr_x8_word<3 >, load_maskz_fp32_fp16_tr_x8_word<4 >, load_maskz_fp32_fp16_tr_x8_word<5 >,
1205
+ load_maskz_fp32_fp16_tr_x8_word<6 >, load_maskz_fp32_fp16_tr_x8_word<7 >, load_maskz_fp32_fp16_tr_x8_word<8 >};
1206
+ #ifdef __GNUC__
1207
+ #pragma GCC diagnostic pop
1208
+ #endif
1209
+
1121
1210
#ifdef __GNUC__
1122
1211
#pragma GCC pop_options
1123
1212
#else
0 commit comments