@@ -547,7 +547,7 @@ struct CollectiveMainloopFwdSm90 {
547547 return nullptr ;
548548 }
549549 }();
550-
550+
551551 auto const shape_Qv_packed = cute::conditional_return<!PackGQA>(
552552 shape_Qv,
553553 make_shape (make_shape (qhead_per_khead, get<0 >(shape_Qv)), get<1 >(shape_Qv), get<2 >(args.shape_K ), get<3 >(shape_Qv))
@@ -1007,7 +1007,6 @@ struct CollectiveMainloopFwdSm90 {
10071007 static constexpr int kBlockN = get<1 >(TileShape_MNK{});
10081008
10091009 // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda
1010- // block index
10111010 int const m_block = get<0 >(block_coord);
10121011 int const bidh = get<1 >(block_coord);
10131012 int const bidb = get<2 >(block_coord);
@@ -1103,7 +1102,7 @@ struct CollectiveMainloopFwdSm90 {
11031102 flash::Mask<kBlockM , kBlockN , PackGQA, TiledMmaQK> mask (
11041103 thread_idx, seqlen_q, seqlen_k, params.window_size_left , params.window_size_right , 0 - n_offset /* sink_token_length*/ ,
11051104 params.qhead_per_khead_divmod ,
1106- params.cp_world_size , params.cp_rank , seqlen_info.cp_tot_seqlen_k
1105+ params.cp_world_size , params.cp_rank , seqlen_info.tot_seqlen_k
11071106 );
11081107
11091108 float softcap_val = params.softcap_val ;
@@ -1211,7 +1210,6 @@ struct CollectiveMainloopFwdSm90 {
12111210 }
12121211
12131212 if constexpr (IntraWGOverlap) {
1214-
12151213 Tensor tSrS = partition_fragment_C (tiled_mma_qk, select<0 , 1 >(TileShape_MNK{}));
12161214 consumer_wait (pipeline_k, smem_pipe_read);
12171215 flash::gemm</* zero_init=*/ true , /* wg_wait=*/ -1 >(tiled_mma_qk, tSrQ, tSrK (_, _, _, smem_pipe_read.index ()), tSrS);
@@ -1283,8 +1281,7 @@ struct CollectiveMainloopFwdSm90 {
12831281 };
12841282
12851283 if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking
1286- auto mask_fn = [&](auto & tSrS, int n_block) {
1287- mask.template apply <false /* Seqlenk_mask*/ , Is_causal, Is_local>(tSrS, m_block, n_block); };
1284+ auto mask_fn = [&](auto & tSrS, int n_block) { mask.template apply <false /* Seqlenk_mask*/ , Is_causal, Is_local>(tSrS, m_block, n_block); };
12881285 int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod .divide (m_block * kBlockM );
12891286 // If local, blocking (window_size_right + window_size_left)
12901287 int const n_block_min_causal_local_mask =
@@ -1297,13 +1294,15 @@ struct CollectiveMainloopFwdSm90 {
12971294
12981295 int const m_idx_max = !PackGQA ? (m_block + 1 ) * kBlockM : params.qhead_per_khead_divmod .divide ((m_block + 1 ) * kBlockM - 1 ) + 1 ;
12991296 // If local, blocking (m_idx_max - m_idx_min)
1297+ // When context parallelism (CP) is disabled, tot_seqlen_k equals seqlen_k and cp_world_size is 1.
1298+ // cp_world_size is guaranteed to be greater than 0.
13001299 int const n_block_min_before_local_mask = !Is_local
13011300 ? n_block_min
13021301 : std::max (n_block_min,
1303- cute::ceil_div (m_idx_max +
1304- params.cp_world_size * seqlen_k -
1305- seqlen_q - params. window_size_left ,
1306- params. cp_world_size * kBlockN ));
1302+ cute::ceil_div (
1303+ cute::ceil_div (m_idx_max + seqlen_info. tot_seqlen_k - seqlen_q - params.window_size_left - seqlen_info. cp_rank ,
1304+ seqlen_info. cp_world_size ) ,
1305+ kBlockN ));
13071306 auto no_mask_fn = [](auto & tSrS, int n_block) { };
13081307 #pragma unroll 1
13091308 for (; n_block >= n_block_min_before_local_mask; --n_block) {
@@ -1429,7 +1428,7 @@ struct CollectiveMainloopFwdSm90 {
14291428 // Tensor scores_scale = softmax.finalize(v_descale);
14301429 Tensor scores_scale = make_tensor_like (softmax.row_max );
14311430 finalize_dispatch (scores_scale, v_descale);
1432-
1431+
14331432 if constexpr (LargeHeadDimV) {
14341433 cutlass::arch::NamedBarrier::sync (NumMmaThreads, static_cast <uint32_t >(FwdNamedBarriers::PEmpty) /* id*/ );
14351434 store_scales (scores_scale, smem_pipe_read.index ());
0 commit comments