Skip to content

Commit

Permalink
Revert num_splits in flash_bwd_kernel.h for large model (PaddlePaddle#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTrainingG authored Oct 27, 2023
1 parent b74460b commit 0fa5933
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions csrc/flash_attn/src/flash_bwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -1585,15 +1585,15 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
// The block index for the head.
const int bidh = blockIdx.z;
constexpr int kBlockN = Kernel_traits::kBlockN;
if (params.num_splits == 1) { // means grid.x = 1, blockIdx.x = 0;
int loop_step_x = 0;
for(int i = 0; i < params.seqlen_k; i+= kBlockN) {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, Is_attn_mask, /*Seq_parallel=*/true>(params, bidb, bidh, loop_step_x);
loop_step_x += 1;
}
} else {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, Is_attn_mask, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
}
//if (params.num_splits == 1) { // means grid.x = 1, blockIdx.x = 0;
// int loop_step_x = 0;
// for(int i = 0; i < params.seqlen_k; i+= kBlockN) {
// compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, Is_attn_mask, /*Seq_parallel=*/true>(params, bidb, bidh, loop_step_x);
// loop_step_x += 1;
// }
//} else {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, Is_attn_mask, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
//}
}

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit 0fa5933

Please sign in to comment.