@@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
580580 for (; tile_scheduler.is_valid (); ++tile_scheduler) {
581581 auto blk_coord = tile_scheduler.get_block_coord ();
582582 auto problem_shape = params.problem_shape ;
583- auto local_split_kv = params.split_kv ;
583+ auto local_split_kv = params.split_kv ;
584584 if (params.mainloop .ptr_seq != nullptr ) {
585585 get<1 >(problem_shape) = params.mainloop .ptr_seq [get<2 >(blk_coord)];
586- if (params.ptr_split_kv != nullptr ) {
586+ if (params.ptr_split_kv != nullptr ) {
587587 local_split_kv = params.ptr_split_kv [get<2 >(blk_coord)];
588588 }
589589 }
590- if (local_split_kv <= get<3 >(blk_coord))
591- continue ;
590+ if (local_split_kv <= get<3 >(blk_coord))
591+ continue ;
592592 load_page_table (
593593 blk_coord,
594594 problem_shape,
595595 params.mainloop ,
596596 shared_storage.tensors ,
597597 pipeline_page_table, pipeline_pt_producer_state,
598- local_split_kv
598+ local_split_kv
599599 );
600600 }
601601 }
@@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
604604 CUTLASS_PRAGMA_NO_UNROLL
605605 for (; tile_scheduler.is_valid (); ++tile_scheduler) {
606606 auto blk_coord = tile_scheduler.get_block_coord ();
607- auto problem_shape = params.problem_shape ;
608- auto local_split_kv = params.split_kv ;
607+ auto problem_shape = params.problem_shape ;
608+ auto local_split_kv = params.split_kv ;
609609 if (params.mainloop .ptr_seq != nullptr ) {
610610 get<1 >(problem_shape) = params.mainloop .ptr_seq [get<2 >(blk_coord)];
611- if (params.ptr_split_kv != nullptr ) {
611+ if (params.ptr_split_kv != nullptr ) {
612612 local_split_kv = params.ptr_split_kv [get<2 >(blk_coord)];
613613 }
614614 }
615- if (local_split_kv <= get<3 >(blk_coord))
615+ if (local_split_kv <= get<3 >(blk_coord))
616616 continue ;
617617 load_cpasync (
618618 blk_coord,
@@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
621621 params.mainloop_params ,
622622 shared_storage.tensors ,
623623 pipeline_load_qk, pipeline_load_qk_producer_state,
624- local_split_kv,
624+ local_split_kv,
625625 /* must be shared pipe */
626626 pipeline_page_table, pipeline_pt_consumer_state
627627 );
@@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
633633 CUTLASS_PRAGMA_NO_UNROLL
634634 for (; tile_scheduler.is_valid (); ++tile_scheduler) {
635635 auto blk_coord = tile_scheduler.get_block_coord ();
636- auto problem_shape = params.problem_shape ;
637- auto local_split_kv = params.split_kv ;
636+ auto problem_shape = params.problem_shape ;
637+ auto local_split_kv = params.split_kv ;
638638 if (params.mainloop .ptr_seq != nullptr ) {
639639 get<1 >(problem_shape) = params.mainloop .ptr_seq [get<2 >(blk_coord)];
640- if (params.ptr_split_kv != nullptr ) {
641- local_split_kv = params.ptr_split_kv [get<2 >(blk_coord)];
642- }
640+ if (params.ptr_split_kv != nullptr ) {
641+ local_split_kv = params.ptr_split_kv [get<2 >(blk_coord)];
642+ }
643643 }
644- if (local_split_kv <= get<3 >(blk_coord))
644+ if (local_split_kv <= get<3 >(blk_coord))
645645 continue ;
646646 load_tma</* paged= */ true >(
647647 blk_coord,
@@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
651651 shared_storage.tensors ,
652652 pipeline_load_qk, pipeline_load_qk_producer_state,
653653 pipeline_load_qk, pipeline_load_qk_producer_state,
654- local_split_kv
654+ local_split_kv
655655 );
656656 cutlass::arch::NamedBarrier ((kNumComputeWarps + kNumLoadWarps ) * NumThreadsPerWarp, kNamedBarrierEpilogue ).arrive_and_wait ();
657657 }
@@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
660660 CUTLASS_PRAGMA_NO_UNROLL
661661 for (; tile_scheduler.is_valid (); ++tile_scheduler) {
662662 auto blk_coord = tile_scheduler.get_block_coord ();
663- auto problem_shape = params.problem_shape ;
664- auto local_split_kv = params.split_kv ;
663+ auto problem_shape = params.problem_shape ;
664+ auto local_split_kv = params.split_kv ;
665665 if (params.mainloop .ptr_seq != nullptr ) {
666666 get<1 >(problem_shape) = params.mainloop .ptr_seq [get<2 >(blk_coord)];
667- if (params.ptr_split_kv != nullptr ) {
667+ if (params.ptr_split_kv != nullptr ) {
668668 local_split_kv = params.ptr_split_kv [get<2 >(blk_coord)];
669- }
669+ }
670670 }
671- if (local_split_kv <= get<3 >(blk_coord))
671+ if (local_split_kv <= get<3 >(blk_coord))
672672 continue ;
673673 load_tma<false >(
674674 blk_coord,
@@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
678678 shared_storage.tensors ,
679679 pipeline_load_qk, pipeline_load_qk_producer_state,
680680 pipeline_load_qk, pipeline_load_qk_producer_state,
681- local_split_kv
681+ local_split_kv
682682 );
683683 cutlass::arch::NamedBarrier ((kNumComputeWarps + kNumLoadWarps ) * NumThreadsPerWarp, kNamedBarrierEpilogue ).arrive_and_wait ();
684684 }
@@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
694694 for (; tile_scheduler.is_valid (); ++tile_scheduler) {
695695 auto blk_coord = tile_scheduler.get_block_coord ();
696696 auto problem_shape = params.problem_shape ;
697- auto local_split_kv = params.split_kv ;
697+ auto local_split_kv = params.split_kv ;
698698 if (params.mainloop .ptr_seq != nullptr ) {
699699 get<1 >(problem_shape) = params.mainloop .ptr_seq [get<2 >(blk_coord)];
700700 if (params.ptr_split_kv != nullptr ) {
701701 local_split_kv = params.ptr_split_kv [get<2 >(blk_coord)];
702702 }
703703 }
704- if (local_split_kv <= get<3 >(blk_coord))
704+ if (local_split_kv <= get<3 >(blk_coord))
705705 continue ;
706706 mma (blk_coord,
707707 problem_shape,
@@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
711711 pipeline_mma_s, pipeline_mma_s_producer_state,
712712 pipeline_p_mma, pipeline_p_mma_consumer_state,
713713 pipeline_mma_o, pipeline_mma_o_producer_state,
714- local_split_kv
714+ local_split_kv
715715 );
716716 }
717717 }
@@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
726726 for (; tile_scheduler.is_valid (); ++tile_scheduler) {
727727 auto blk_coord = tile_scheduler.get_block_coord ();
728728 auto problem_shape = params.problem_shape ;
729- auto split_kv = params.split_kv ;
730- auto local_split_kv = split_kv;
729+ auto split_kv = params.split_kv ;
730+ auto local_split_kv = split_kv;
731731 if (params.mainloop .ptr_seq != nullptr ) {
732732 get<1 >(problem_shape) = params.mainloop .ptr_seq [get<2 >(blk_coord)];
733- if (params.ptr_split_kv != nullptr ) {
733+ if (params.ptr_split_kv != nullptr ) {
734734 local_split_kv = params.ptr_split_kv [get<2 >(blk_coord)];
735735 }
736736 }
737- if (local_split_kv <= get<3 >(blk_coord))
737+ if (local_split_kv <= get<3 >(blk_coord))
738738 continue ;
739739 compute (
740740 blk_coord,
@@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
745745 pipeline_mma_s, pipeline_mma_s_consumer_state,
746746 pipeline_p_mma, pipeline_p_mma_producer_state,
747747 pipeline_mma_o, pipeline_mma_o_consumer_state,
748- local_split_kv
748+ local_split_kv
749749 );
750750 }
751751
@@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
19001900 cutlass::arch::NamedBarrier (
19011901 (kNumComputeWarps + kNumLoadWarps ) * NumThreadsPerWarp,
19021902 kNamedBarrierEpilogue
1903- ).arrive ();
1903+ ).arrive_and_wait ();
19041904
19051905 return ;
19061906 }
0 commit comments