@@ -412,6 +412,9 @@ struct CollectiveMainloopFwdSm90 {
412412        int  const * const  leftpad_k = nullptr ;
413413        int  const * const  seqlens_rotary = nullptr ;
414414        ElementSAux const * const  ptr_S_aux = nullptr ;
415+         //  Context parallelism (CP) parameters
416+         int  const  cp_world_size = 1 ;
417+         int  const  cp_rank = 0 ;
415418    };
416419
417420    //  Device side kernel params
@@ -469,6 +472,8 @@ struct CollectiveMainloopFwdSm90 {
469472        int  const * const  leftpad_k = nullptr ;
470473        int  const * const  seqlens_rotary = nullptr ;
471474        ElementSAux const * const  ptr_S_aux = nullptr ;
475+         int  cp_world_size = 1 ;
476+         int  cp_rank = 0 ;
472477    };
473478
474479    static  Params
@@ -540,7 +545,7 @@ struct CollectiveMainloopFwdSm90 {
540545                return  nullptr ;
541546            }
542547        }();
543-          
548+ 
544549        auto  const  shape_Qv_packed = cute::conditional_return<!PackGQA>(
545550            shape_Qv,
546551            make_shape (make_shape (qhead_per_khead, get<0 >(shape_Qv)), get<1 >(shape_Qv), get<2 >(args.shape_K ), get<3 >(shape_Qv))
@@ -584,7 +589,8 @@ struct CollectiveMainloopFwdSm90 {
584589                args.kv_batch_idx ,
585590                args.cu_seqlens_q , args.cu_seqlens_k , args.cu_seqlens_k_new ,
586591                args.seqused_q , args.seqused_k , args.leftpad_k , args.seqlens_rotary ,
587-                 args.ptr_S_aux };
592+                 args.ptr_S_aux ,
593+                 args.cp_world_size , args.cp_rank };
588594    }
589595
590596    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -999,6 +1005,7 @@ struct CollectiveMainloopFwdSm90 {
9991005        static  constexpr  int  kBlockN  = get<1 >(TileShape_MNK{});
10001006
10011007        //  can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda
1008+         //  Unpack the components of block_coord into individual named indices
10021009        int  const  m_block = get<0 >(block_coord);
10031010        int  const  bidh = get<1 >(block_coord);
10041011        int  const  bidb = get<2 >(block_coord);
@@ -1093,7 +1100,8 @@ struct CollectiveMainloopFwdSm90 {
10931100        //  But we subtract n_offset for consistency in mask calculations
10941101        flash::Mask<kBlockM , kBlockN , PackGQA, TiledMmaQK> mask (
10951102            thread_idx, seqlen_q, seqlen_k, params.window_size_left , params.window_size_right , 0  - n_offset /* sink_token_length*/ 
1096-             params.qhead_per_khead_divmod 
1103+             params.qhead_per_khead_divmod ,
1104+             params.cp_world_size , params.cp_rank 
10971105        );
10981106
10991107        float  softcap_val = params.softcap_val ;
@@ -1201,6 +1209,7 @@ struct CollectiveMainloopFwdSm90 {
12011209        }
12021210
12031211        if  constexpr  (IntraWGOverlap) {
1212+ 
12041213            Tensor tSrS = partition_fragment_C (tiled_mma_qk, select<0 , 1 >(TileShape_MNK{}));
12051214            consumer_wait (pipeline_k, smem_pipe_read);
12061215            flash::gemm</* zero_init=*/ true , /* wg_wait=*/ 1 >(tiled_mma_qk, tSrQ, tSrK (_, _, _, smem_pipe_read.index ()), tSrS);
@@ -1272,7 +1281,8 @@ struct CollectiveMainloopFwdSm90 {
12721281            };
12731282
12741283            if  constexpr  (Is_causal || Is_local) { //  Separate iterations with causal or local masking
1275-                 auto  mask_fn = [&](auto & tSrS, int  n_block) { mask.template  apply <false  /* Seqlenk_mask*/ 
1284+                 auto  mask_fn = [&](auto & tSrS, int  n_block) {
1285+                   mask.template  apply <false  /* Seqlenk_mask*/ 
12761286                int  const  m_idx_min = !PackGQA ? m_block * kBlockM  : params.qhead_per_khead_divmod .divide (m_block * kBlockM );
12771287                //  If local, blocking (window_size_right + window_size_left)
12781288                int  const  n_block_min_causal_local_mask =
@@ -1288,7 +1298,7 @@ struct CollectiveMainloopFwdSm90 {
12881298            int  const  n_block_min_before_local_mask = !Is_local
12891299                ? n_block_min
12901300                : std::max (n_block_min,
1291-                            cute::ceil_div (m_idx_max + seqlen_k - seqlen_q - params.window_size_left , kBlockN ));
1301+                            cute::ceil_div (m_idx_max + params. cp_world_size  *  seqlen_k - seqlen_q - params.window_size_left , kBlockN ));
12921302            auto  no_mask_fn = [](auto & tSrS, int  n_block) { };
12931303            #pragma  unroll 1
12941304            for  (; n_block >= n_block_min_before_local_mask; --n_block) {
@@ -1414,7 +1424,7 @@ struct CollectiveMainloopFwdSm90 {
14141424            //  Tensor scores_scale = softmax.finalize(v_descale);
14151425            Tensor scores_scale = make_tensor_like (softmax.row_max );
14161426            finalize_dispatch (scores_scale, v_descale);
1417-              
1427+ 
14181428            if  constexpr  (LargeHeadDimV) {
14191429                cutlass::arch::NamedBarrier::sync (NumMmaThreads, static_cast <uint32_t >(FwdNamedBarriers::PEmpty) /* id*/ 
14201430                store_scales (scores_scale, smem_pipe_read.index ());
0 commit comments