Skip to content

Commit

Permalink
Optimize softmax via flash attention v2 (#2468)
Browse files Browse the repository at this point in the history
* optimize softmax as flash attention v2
  • Loading branch information
Valentine233 authored Jan 23, 2024
1 parent e046f5c commit f88a7d1
Showing 1 changed file with 18 additions and 47 deletions.
65 changes: 18 additions & 47 deletions csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,25 +191,25 @@ inline Vectorized<float> exp_u20(Vectorized<float> data) {

// 1) out = exp(a - val)
// 2) val = sum(out)
template <typename scalar_t>
template <typename T1, typename T2>
inline void _exp_reduce_sum_fusion_kernel(
scalar_t* a,
T1* a,
const int& size,
scalar_t* out,
scalar_t& val) {
auto vec_size = at::vec::Vectorized<scalar_t>::size();
auto vec_max = at::vec::Vectorized<scalar_t>(val);
scalar_t tmp_sum = 0;
auto vec_tmp_sum = at::vec::Vectorized<scalar_t>(tmp_sum);
T2* out,
T1& val) {
auto vec_size = at::vec::Vectorized<T1>::size();
auto vec_max = at::vec::Vectorized<T1>(val);
T1 tmp_sum = 0;
auto vec_tmp_sum = at::vec::Vectorized<T1>(tmp_sum);
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
auto tmp0 = at::vec::Vectorized<scalar_t>::loadu(a + i);
auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
auto tmp1 = tmp0 - vec_max;
auto tmp2 = exp_u20(tmp1);
vec_tmp_sum += tmp2;
at::native::_store(out + i, tmp2);
}
tmp_sum = at::vec::vec_reduce_all<scalar_t>(
[](at::vec::Vectorized<scalar_t>& x, at::vec::Vectorized<scalar_t>& y) {
tmp_sum = at::vec::vec_reduce_all<T1>(
[](at::vec::Vectorized<T1>& x, at::vec::Vectorized<T1>& y) {
return x + y;
},
vec_tmp_sum);
Expand All @@ -223,27 +223,6 @@ inline void _exp_reduce_sum_fusion_kernel(
val = tmp_sum;
}

// out = a / sum
template <typename T1, typename T2>
inline void _normalization_kernel(
const T1* a,
const T1& sum,
const int& size,
T2* out) {
auto vec_size = at::vec::Vectorized<T1>::size();
auto vec_sum = at::vec::Vectorized<T1>(sum);
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
auto tmp1 = tmp0 / vec_sum;
at::native::_store(out + i, tmp1);
}
for (long i = vec_size * (size / vec_size); i < size; i++) {
auto tmp0 = a[i];
auto tmp1 = tmp0 / sum;
out[i] = tmp1;
}
}

// 1) out = a * scale
// 2) max = max(out)
template <typename scalar_t>
Expand Down Expand Up @@ -767,29 +746,19 @@ void cpu_flash_attention(
_exp_reduce_sum_fusion_kernel(
qk_data + row * kvBlockSize,
kvBlockSize,
qk_data + row * kvBlockSize,
conditional_data_ptr(qk_data, qk_reduced_data) +
row * kvBlockSize,
tmp_sum);
// exp_tmp <- exp(max[row] - max)
exp_tmp = std::exp(qk_max_data[row] - tmp_max);
// sum[row] <- sum + exp_tmp * sum[row]
qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
// max[row] <- max
qk_max_data[row] = tmp_max;
// qk <- qk / sum[row]
accum_t sum_new = qk_sum_data[row];
_normalization_kernel(
qk_data + row * kvBlockSize,
sum_new,
kvBlockSize,
conditional_data_ptr(qk_data, qk_reduced_data) +
row * kvBlockSize);
// dst <- dst * sum_old / sum_new * exp_tmp
// dst <- dst * exp_tmp
if (n > 0) {
accum_t sum_cor = sum_old / sum_new;
at::vec::map<accum_t>(
[sum_cor, exp_tmp](Vec x) {
return x * Vec(sum_cor) * Vec(exp_tmp);
},
[exp_tmp](Vec x) { return x * Vec(exp_tmp); },
dst_data + row * headSize,
dst_data + row * headSize,
headSize);
Expand Down Expand Up @@ -856,10 +825,12 @@ void cpu_flash_attention(
headSize);
}
}
// dst <- dst / sum[row]
// reorder MHA output with strides
for (int64_t row = 0; row < qBlockSize; ++row) {
accum_t sum_reciprocal = 1 / qk_sum_data[row];
at::vec::map<scalar_t>(
[](Vec x) { return x; },
[sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
out_data + i * oStrideB + j * oStrideH + m * oStrideM +
row * oStrideM,
dst_data + row * headSize,
Expand Down

0 comments on commit f88a7d1

Please sign in to comment.