Skip to content

Commit

Permalink
Merge pull request #67 from ROCm/update-ck
Browse files (browse the repository at this point in the history)
[WIP] update to latest ck
  • Loading branch information
rocking5566 authored Jul 10, 2024
2 parents c0f637a + 3b62d48 commit 23a2b1c
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion csrc/composable_kernel
Submodule composable_kernel updated 237 files
10 changes: 10 additions & 0 deletions csrc/flash_attn_ck/mha_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
v.data_ptr(),
alibi_slopes_ptr, // bias
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
nullptr, // lse_acc
nullptr, // o_acc
has_lse ? softmax_lse.data_ptr() : nullptr,
out.data_ptr(),
nullptr, // seqstart_q
Expand All @@ -109,6 +111,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
d, // hdim_v
h, // nhead
h_k, // nhead_k
1, // num_splits
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
Expand All @@ -117,21 +120,28 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
stride_v,
stride_alibi_slopes,
stride_randval,
0, // stride_o_acc,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_randval,
nhead_stride_lse,
0, // nhead_stride_lse_acc
0, // nhead_stride_o_acc
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
0, // batch_stride_bias, FA without bias
batch_stride_randval,
batch_stride_lse,
0, // batch_stride_lse_acc
0, // batch_stride_o_acc
batch_stride_o,
0, // split_stride_lse_acc
0, // split_stride_o_acc
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
Expand Down
10 changes: 10 additions & 0 deletions csrc/flash_attn_ck/mha_varlen_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
v.data_ptr(),
alibi_slopes_ptr, // bias
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
nullptr, // lse_acc
nullptr, // o_acc
has_lse ? softmax_lse.data_ptr() : nullptr,
out.data_ptr(),
seqlens_q.data_ptr(), // seqstart_q
Expand All @@ -113,6 +115,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
d, // hdim_v
h, // nhead
h_k, // nhead_k
1, // num_splits
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
Expand All @@ -121,21 +124,28 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
stride_v,
stride_alibi_slopes,
stride_randval,
0, // stride_o_acc,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_randval,
nhead_stride_lse,
0, // nhead_stride_lse_acc
0, // nhead_stride_o_acc
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
0, // batch_stride_bias, FA without bias
batch_stride_randval,
batch_stride_lse,
0, // batch_stride_lse_acc
0, // batch_stride_o_acc
batch_stride_o,
0, // split_stride_lse_acc
0, // split_stride_o_acc
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def validate_and_update_archs(archs):
"nvcc":
[
"-O3","-std=c++17",
"-mllvm", "-enable-post-misched=0",
"-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
"-fgpu-flush-denormals-to-zero",
"-DCK_ENABLE_BF16",
Expand Down

0 comments on commit 23a2b1c

Please sign in to comment.