@@ -753,6 +753,9 @@ def set_loop_idx_rtp():
753753
754754 for q_block_idx in range (num_q_block_per_pipeline ):
755755
756+ # Initialize a group for parallel drain tasks, with fill resources freed when drains complete.
757+ tg = rt .task_group ()
758+
756759 if number_of_pipelines > 6 :
757760 rt .fill (
758761 inQ .prod (),
@@ -761,6 +764,7 @@ def set_loop_idx_rtp():
761764 2 * head_idx * num_q_block_per_pipeline + q_block_idx * 2
762765 ],
763766 placement = Tile (col = 4 , row = 0 ),
767+ task_group = tg ,
764768 )
765769 rt .fill (
766770 inQ2 .prod (),
@@ -771,21 +775,31 @@ def set_loop_idx_rtp():
771775 + 1
772776 ],
773777 placement = Tile (col = 4 , row = 0 ),
778+ task_group = tg ,
774779 )
775780 else :
776781 rt .fill (
777782 inQ .prod (),
778783 Q ,
779784 tap = Q_tiles [head_idx * num_q_block_per_pipeline + q_block_idx ],
780785 placement = Tile (col = 4 , row = 0 ),
786+ task_group = tg ,
781787 )
782788
783789 # Throw one bd containing the full K and V into the object fifo; does it then transfer chunks of inKV size at a time?
784790 rt .fill (
785- inK .prod (), K , tap = K_tiles [head_idx ], placement = Tile (col = 5 , row = 0 )
791+ inK .prod (),
792+ K ,
793+ tap = K_tiles [head_idx ],
794+ placement = Tile (col = 5 , row = 0 ),
795+ task_group = tg ,
786796 )
787797 rt .fill (
788- inV .prod (), V , tap = V_tiles [head_idx ], placement = Tile (col = 6 , row = 0 )
798+ inV .prod (),
799+ V ,
800+ tap = V_tiles [head_idx ],
801+ placement = Tile (col = 6 , row = 0 ),
802+ task_group = tg ,
789803 )
790804
791805 if number_of_pipelines > 6 :
@@ -797,6 +811,7 @@ def set_loop_idx_rtp():
797811 ],
798812 wait = True ,
799813 placement = Tile (col = 7 , row = 0 ),
814+ task_group = tg ,
800815 )
801816 rt .drain (
802817 memO2 .cons (),
@@ -808,6 +823,7 @@ def set_loop_idx_rtp():
808823 ],
809824 wait = True ,
810825 placement = Tile (col = 7 , row = 0 ),
826+ task_group = tg ,
811827 )
812828 else :
813829 rt .drain (
@@ -816,8 +832,11 @@ def set_loop_idx_rtp():
816832 tap = O_tiles [head_idx * num_q_block_per_pipeline + q_block_idx ],
817833 wait = True ,
818834 placement = Tile (col = 7 , row = 0 ),
835+ task_group = tg ,
819836 )
820837
838+ rt .finish_task_group (tg )
839+
821840 # Create the program from the device type and runtime
822841 if dev == "npu" :
823842 dev_ty = NPU1Col1 ()
0 commit comments