diff --git a/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp b/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp
index 5e383a0dc..3393e4619 100644
--- a/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp
+++ b/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp
@@ -191,25 +191,25 @@ inline Vectorized<float> exp_u20(Vectorized<float> data) {
 
 // 1) out = exp(a - val)
 // 2) val = sum(out)
-template <typename scalar_t>
+template <typename T1, typename T2>
 inline void _exp_reduce_sum_fusion_kernel(
-    scalar_t* a,
+    T1* a,
     const int& size,
-    scalar_t* out,
-    scalar_t& val) {
-  auto vec_size = at::vec::Vectorized<scalar_t>::size();
-  auto vec_max = at::vec::Vectorized<scalar_t>(val);
-  scalar_t tmp_sum = 0;
-  auto vec_tmp_sum = at::vec::Vectorized<scalar_t>(tmp_sum);
+    T2* out,
+    T1& val) {
+  auto vec_size = at::vec::Vectorized<T1>::size();
+  auto vec_max = at::vec::Vectorized<T1>(val);
+  T1 tmp_sum = 0;
+  auto vec_tmp_sum = at::vec::Vectorized<T1>(tmp_sum);
   for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
-    auto tmp0 = at::vec::Vectorized<scalar_t>::loadu(a + i);
+    auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
     auto tmp1 = tmp0 - vec_max;
     auto tmp2 = exp_u20(tmp1);
     vec_tmp_sum += tmp2;
     at::native::_store(out + i, tmp2);
   }
-  tmp_sum = at::vec::vec_reduce_all<scalar_t>(
-      [](at::vec::Vectorized<scalar_t>& x, at::vec::Vectorized<scalar_t>& y) {
+  tmp_sum = at::vec::vec_reduce_all<T1>(
+      [](at::vec::Vectorized<T1>& x, at::vec::Vectorized<T1>& y) {
         return x + y;
       },
       vec_tmp_sum);
@@ -223,27 +223,6 @@ inline void _exp_reduce_sum_fusion_kernel(
   val = tmp_sum;
 }
 
-// out = a / sum
-template <typename T1, typename T2>
-inline void _normalization_kernel(
-    const T1* a,
-    const T1& sum,
-    const int& size,
-    T2* out) {
-  auto vec_size = at::vec::Vectorized<T1>::size();
-  auto vec_sum = at::vec::Vectorized<T1>(sum);
-  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
-    auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
-    auto tmp1 = tmp0 / vec_sum;
-    at::native::_store(out + i, tmp1);
-  }
-  for (long i = vec_size * (size / vec_size); i < size; i++) {
-    auto tmp0 = a[i];
-    auto tmp1 = tmp0 / sum;
-    out[i] = tmp1;
-  }
-}
-
 // 1) out = a * scale
 // 2) max = max(out)
 template <typename scalar_t>
@@ -767,7 +746,8 @@ void cpu_flash_attention(
           _exp_reduce_sum_fusion_kernel(
               qk_data + row * kvBlockSize,
               kvBlockSize,
-              qk_data + row * kvBlockSize,
+              conditional_data_ptr(qk_data, qk_reduced_data) +
+                  row * kvBlockSize,
               tmp_sum);
           // exp_tmp <- exp(max[row] - max)
           exp_tmp = std::exp(qk_max_data[row] - tmp_max);
@@ -775,21 +755,10 @@ void cpu_flash_attention(
           qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
           // max[row] <- max
           qk_max_data[row] = tmp_max;
-          // qk <- qk / sum[row]
-          accum_t sum_new = qk_sum_data[row];
-          _normalization_kernel(
-              qk_data + row * kvBlockSize,
-              sum_new,
-              kvBlockSize,
-              conditional_data_ptr(qk_data, qk_reduced_data) +
-                  row * kvBlockSize);
-          // dst <- dst * sum_old / sum_new * exp_tmp
+          // dst <- dst * exp_tmp
           if (n > 0) {
-            accum_t sum_cor = sum_old / sum_new;
             at::vec::map<accum_t>(
-                [sum_cor, exp_tmp](Vec x) {
-                  return x * Vec(sum_cor) * Vec(exp_tmp);
-                },
+                [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
                 dst_data + row * headSize,
                 dst_data + row * headSize,
                 headSize);
@@ -856,10 +825,12 @@ void cpu_flash_attention(
                 headSize);
           }
         }
+        // dst <- dst / sum[row]
         // reorder MHA output with strides
         for (int64_t row = 0; row < qBlockSize; ++row) {
+          accum_t sum_reciprocal = 1 / qk_sum_data[row];
           at::vec::map<scalar_t>(
-              [](Vec x) { return x; },
+              [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
               out_data + i * oStrideB + j * oStrideH + m * oStrideM +
                   row * oStrideM,
               dst_data + row * headSize,
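
Note for reviewers: the patch drops the per-kv-block _normalization_kernel and keeps the accumulator unnormalized, rescaling it by exp_tmp whenever the running max changes and dividing by qk_sum_data[row] exactly once when the output is reordered. The sketch below is a minimal scalar illustration of that deferred-normalization (online softmax) scheme, not the vectorized kernel itself; the function name, loop structure, and float-only types are assumptions made for illustration.

// Reviewer sketch (not part of the patch): scalar deferred-normalization
// softmax-attention for a single query row. Illustrative names only.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

std::vector<float> attend_one_row(
    const std::vector<float>& scores,               // raw q.k^T scores for one query row
    const std::vector<std::vector<float>>& values,  // [kv_size][head_size]
    int block_size) {
  const int kv_size = static_cast<int>(scores.size());
  const int head_size = static_cast<int>(values[0].size());
  float running_max = -std::numeric_limits<float>::infinity();
  float running_sum = 0.f;
  std::vector<float> dst(head_size, 0.f);  // unnormalized output accumulator

  for (int n = 0; n < kv_size; n += block_size) {
    const int end = std::min(n + block_size, kv_size);
    // New running max over everything seen so far, including this block.
    float block_max = running_max;
    for (int k = n; k < end; ++k) block_max = std::max(block_max, scores[k]);
    // exp_tmp rescales the sum and output accumulated under the old max
    // (mirrors "dst <- dst * exp_tmp"); no division by the sum happens here.
    const float exp_tmp = std::exp(running_max - block_max);
    running_sum *= exp_tmp;
    for (float& d : dst) d *= exp_tmp;
    // Accumulate this block's contribution, still unnormalized.
    for (int k = n; k < end; ++k) {
      const float p = std::exp(scores[k] - block_max);  // exp(qk - max)
      running_sum += p;                                 // sum[row] update
      for (int d = 0; d < head_size; ++d) dst[d] += p * values[k][d];
    }
    running_max = block_max;
  }
  // Normalize exactly once at the end (mirrors "dst <- dst / sum[row]").
  const float sum_reciprocal = 1.f / running_sum;
  for (float& d : dst) d *= sum_reciprocal;
  return dst;
}

Deferring the division this way replaces a full normalization pass over qk per kv block with a single reciprocal and multiply per row at write-out, which is what the removed sum_old / sum_new correction previously compensated for.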