Commit 5714f9d

Fix LSE output error in FA2 kvsplit (#87)
Signed-off-by: griii <guo_rui@mail.ustc.edu.cn>
1 parent ee4d25b commit 5714f9d

1 file changed: +12 -1 lines changed

csrc/flash_attn/flash_api.cpp

Lines changed: 12 additions & 1 deletion
@@ -765,7 +765,18 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         // q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
         int64_t lse_size_before[] = {num_heads, batch_size, max_seqlen_q};
         int64_t lse_size_after[] = {num_heads * max_seqlen_q, batch_size};
-        softmax_lse = softmax_lse.reshape(lse_size_before).transpose(1, 2).reshape(lse_size_after);
+
+
+        if (params.num_splits > 1){
+            // When KV-split is enabled (num_splits > 1), LSE is first computed partially through lse_accum tensors. Then, an additional kernel, combine_attn_seqk_parallel, reduces these partials into the final LSE.
+            // This kernel produces LSE in a [seqlen_q, h, b] layout which can be directly used as it is already in the canonical form.
+            softmax_lse = softmax_lse.reshape(lse_size_after);
+        }else{
+            // The standard forward kernel produces LSE in a [b, h, seqlen_q] layout.
+            // It must be transposed to the canonical [seqlen_q, h, b] layout.
+            softmax_lse = softmax_lse.reshape(lse_size_before).transpose(1, 2).reshape(lse_size_after);
+        }
+
     }

     return {out, softmax_lse};
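
To make the layout difference concrete, here is a minimal sketch in plain C++ (no libtorch) of what the two branches do to the underlying data. It assumes the shapes declared in the diff: `lse_size_before` is {num_heads, batch_size, max_seqlen_q} and `transpose(1, 2)` followed by the final reshape yields {num_heads * max_seqlen_q, batch_size}. The helper names `transpose_hbs_to_hsb` and `finalize_lse` are hypothetical and do not exist in the repository; the flat `std::vector<float>` stands in for the tensor storage.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: copies a row-major [h, b, s] buffer into [h, s, b]
// order. On flat data this matches
// reshape({h, b, s}).transpose(1, 2).reshape({h * s, b}).
std::vector<float> transpose_hbs_to_hsb(const std::vector<float>& src,
                                        std::size_t h, std::size_t b,
                                        std::size_t s) {
    assert(src.size() == h * b * s);
    std::vector<float> dst(src.size());
    for (std::size_t ih = 0; ih < h; ++ih)
        for (std::size_t ib = 0; ib < b; ++ib)
            for (std::size_t is = 0; is < s; ++is)
                dst[(ih * s + is) * b + ib] = src[(ih * b + ib) * s + is];
    return dst;
}

// Hypothetical stand-in for the branch added in this commit.
std::vector<float> finalize_lse(const std::vector<float>& lse, int num_splits,
                                std::size_t h, std::size_t b, std::size_t s) {
    if (num_splits > 1) {
        // Assumed: the combine kernel already wrote the final order, so the
        // reshape alone leaves the flat data untouched.
        return lse;
    }
    // Standard path: the data still needs the transpose.
    return transpose_hbs_to_hsb(lse, h, b, s);
}
```

With h = 1, b = 2, s = 3, the input {0, 1, 2, 3, 4, 5} becomes {0, 3, 1, 4, 2, 5} on the standard path. Feeding data that is already in the final order through the same transpose would permute it again, which appears to be the kind of LSE error the unconditional transpose caused when `num_splits > 1`.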
