From b7ffdbf6d1b9e3ad26f56b1e6637b5c2f091fb7b Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 8 Jul 2024 18:45:50 +0000 Subject: [PATCH 1/2] update to lastest ck --- csrc/composable_kernel | 2 +- csrc/flash_attn_ck/mha_fwd.cpp | 10 ++++++++++ csrc/flash_attn_ck/mha_varlen_fwd.cpp | 10 ++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/csrc/composable_kernel b/csrc/composable_kernel index fa129c1a5..8182976c3 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit fa129c1a5db62354c4b39857d2b1598bb618f8ce +Subproject commit 8182976c37433808b5e3a27a6536d1b74b0c23a1 diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp index e9bae19e5..8d1fce0f8 100644 --- a/csrc/flash_attn_ck/mha_fwd.cpp +++ b/csrc/flash_attn_ck/mha_fwd.cpp @@ -96,6 +96,8 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, v.data_ptr(), alibi_slopes_ptr, // bias has_dropout_randval ? dropout_randval.data_ptr() : nullptr, + nullptr, // lse_acc + nullptr, // o_acc has_lse ? softmax_lse.data_ptr() : nullptr, out.data_ptr(), nullptr, // seqstart_q @@ -109,6 +111,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, d, // hdim_v h, // nhead h_k, // nhead_k + 1, // num_splits softmax_scale, // scale_s 1, // scale_p 1, // scale_o @@ -117,6 +120,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, stride_v, stride_alibi_slopes, stride_randval, + 0, // stride_o_acc, stride_o, nhead_stride_q, nhead_stride_k, @@ -124,6 +128,8 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, 0, // nhead_stride_bias, FA without bias nhead_stride_randval, nhead_stride_lse, + 0, // nhead_stride_lse_acc + 0, // nhead_stride_o_acc nhead_stride_o, batch_stride_q, batch_stride_k, @@ -131,7 +137,11 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, 0, // batch_stride_bias, FA without bias batch_stride_randval, batch_stride_lse, + 0, // batch_stride_lse_acc + 0, // batch_stride_o_acc batch_stride_o, + 0, // split_stride_lse_acc + 0, // split_stride_o_acc mask.left, mask.right, static_cast(mask.type), diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index 712e4e577..cab0dd942 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -100,6 +100,8 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, v.data_ptr(), alibi_slopes_ptr, // bias has_dropout_randval ? dropout_randval.data_ptr() : nullptr, + nullptr, // lse_acc + nullptr, // o_acc has_lse ? softmax_lse.data_ptr() : nullptr, out.data_ptr(), seqlens_q.data_ptr(), // seqstart_q @@ -113,6 +115,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, d, // hdim_v h, // nhead h_k, // nhead_k + 1, // num_splits softmax_scale, // scale_s 1, // scale_p 1, // scale_o @@ -121,6 +124,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, stride_v, stride_alibi_slopes, stride_randval, + 0, // stride_o_acc, stride_o, nhead_stride_q, nhead_stride_k, @@ -128,6 +132,8 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, 0, // nhead_stride_bias, FA without bias nhead_stride_randval, nhead_stride_lse, + 0, // nhead_stride_lse_acc + 0, // nhead_stride_o_acc nhead_stride_o, batch_stride_q, batch_stride_k, @@ -135,7 +141,11 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, 0, // batch_stride_bias, FA without bias batch_stride_randval, batch_stride_lse, + 0, // batch_stride_lse_acc + 0, // batch_stride_o_acc batch_stride_o, + 0, // split_stride_lse_acc + 0, // split_stride_o_acc mask.left, mask.right, static_cast(mask.type), From 3b62d4807fcba158e9ebaf2926ab11aee53be820 Mon Sep 17 00:00:00 2001 From: rocking Date: Wed, 10 Jul 2024 06:57:00 +0000 Subject: [PATCH 2/2] Add necessary compile flag --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 511d19520..391eb9254 100644 --- a/setup.py +++ b/setup.py @@ -317,6 +317,7 @@ def validate_and_update_archs(archs): "nvcc": [ "-O3","-std=c++17", + "-mllvm", "-enable-post-misched=0", "-DCK_TILE_FMHA_FWD_FAST_EXP2=1", "-fgpu-flush-denormals-to-zero", "-DCK_ENABLE_BF16",