From 5fc132ac11e78d26471ca09e5ba0cd817c3424d8 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 11 Jan 2024 20:48:03 +0800 Subject: [PATCH] fix Is_equal_qk error (#34) --- csrc/flash_attn/src/flash_fwd_launch_template.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 6c638261766535..ce162ee1b12d49 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -35,7 +35,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool return_softmax = params.p_ptr != nullptr; const bool is_attn_mask = params.attn_mask_ptr != nullptr; - const bool is_equal_qk = (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask); + const bool is_equal_qk = (params.cu_seqlens_q == nullptr) && (params.cu_seqlens_k == nullptr) && (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask); BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {