Skip to content

Commit

Permalink
[Cherry-pick] Support flash attention 2 with causal masking when KV's…
Browse files Browse the repository at this point in the history
… seq length is longer than Q's seq length. (PaddlePaddle#36)

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
  • Loading branch information
Wanglongzhi2001 and BoxiangW authored Apr 8, 2024
1 parent 4b554d0 commit 86e9188
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
8 changes: 5 additions & 3 deletions csrc/flash_attn/src/flash_bwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded;

int m_block = m_block_max - 1;
int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM;
int m_block_min = !Is_causal ? 0 : (n_block * kBlockN - int(binfo.actual_seqlen_k - binfo.actual_seqlen_q)) / kBlockM;
m_block_min = m_block_min < 0 ? 0 : m_block_min;

// We might need to exit early and write 0 to dK and dV.
// Otherwise we get wrong result for the case where we don't enter the for loop.
Expand Down Expand Up @@ -873,7 +874,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
} else if (m_block * kBlockM < (n_block + 1) * kBlockN
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q, binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
AtomLayoutMS * 16);
}
Expand Down Expand Up @@ -1424,7 +1426,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
// the corresponding values of K would be 0, so the result would still be correct.
if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) {
flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q, binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
AtomLayoutMS * 16);
}
Expand Down
11 changes: 6 additions & 5 deletions csrc/flash_attn/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
if (Is_causal) {
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
n_block_max = std::min(n_block_max, cute::ceil_div(
(m_block + 1) * kBlockM + int(binfo.actual_seqlen_k - binfo.actual_seqlen_q), kBlockN));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
// }
Expand Down Expand Up @@ -429,10 +430,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
gSparseMask.data() = gSparseMask.data() + (-kBlockN);
} else {
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16);
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_q, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
}
Expand Down
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t

template <typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
const uint32_t warp_row_stride) {
const uint32_t max_seqlen_q, const uint32_t max_seqlen_k,
const uint32_t row_idx_offset_, const uint32_t warp_row_stride) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const uint32_t lane_id = threadIdx.x % 32;
Expand All @@ -156,7 +156,7 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const uint32_t row_idx = row_idx_base + i * 8;
const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const uint32_t col_idx_base = col_idx_offset + nj * 8;
Expand Down

0 comments on commit 86e9188

Please sign in to comment.