@@ -619,34 +619,70 @@ def build(self,
619619
620620 if self .dcp_world_size > 1 :
621621 # init custom mask for interleave kv cache
622- mask_arr = []
623- q_lens = (qo_indptr_cpu [1 :] -
624- qo_indptr_cpu [:- 1 ]).cpu ().tolist ()
622+ # |-------total_lens----------|
623+ # |--context_lens--|--q_lens--|
624+ # Example: dcp_size=2, dcp_rank=0
625+ # For a SINGLE prefill seq, q_lens=3, total_lens=5
626+ # k_lens on RANK0 is (5 - 1 - 0) // 2 + 1 = 3
627+ # mask.shape = [q_lens, k_lens] = [3,3]
628+ # mask [[True, True, False],
629+ # [True, True, False],
630+ # [True, True, True]]
631+ dcp_rank = self .dcp_rank
632+ dcp_size = self .dcp_world_size
633+
634+ q_lens = (qo_indptr_cpu [1 :] - qo_indptr_cpu [:- 1 ]).to (
635+ dtype = torch .int64 , device = self .device )
625636 total_lens = seq_lens_cpu [prefill_start :prefill_start +
626- num_prefills ].to (
627- torch .int64 ).tolist ()
628- r = self .dcp_rank
629- p = self .dcp_world_size
630- for i in range (num_prefills ):
631- Q = int (q_lens [i ])
632- T = int (total_lens [i ])
633- if Q <= 0 :
634- mask_arr .append (torch .zeros (0 , dtype = torch .bool ))
635- continue
636- L = T - Q
637- rightmost = L + Q - 1
638- if rightmost < r :
639- mask_arr .append (torch .zeros (0 , dtype = torch .bool ))
640- continue
641- K = ((rightmost - r ) // p ) + 1
642- j = torch .arange (K )
643- t = torch .arange (Q )
644- upper = (L + t - r ) // p
645- upper = torch .clamp (upper , min = - 1 )
646- mask_i = (j .unsqueeze (0 ) <= upper .unsqueeze (1 )) & (
647- upper .unsqueeze (1 ) >= 0 )
648- mask_arr .append (mask_i .flatten ())
649- custom_mask = torch .cat (mask_arr , dim = 0 ).to (self .device )
637+ num_prefills ].to (dtype = torch .int64 ,
638+ device = self .device )
639+ context_lens = total_lens - q_lens
640+ # max indices for global sequences
641+ max_indices = total_lens - 1
642+ # if max_indices are smaller than dcp_rank,
643+ # current rank has no kv cache, is invalid,
644+ # the mask is skipped
645+ valid = (max_indices >= dcp_rank )
646+ assert torch .any (valid ), "There is no valid sequence"
647+
648+ # local kv lens on current dcp_rank
649+ k_lens = torch .div (max_indices - dcp_rank ,
650+ dcp_size ,
651+ rounding_mode = "floor" ) + 1
652+ k_lens = torch .where (
653+ valid ,
654+ k_lens ,
655+ torch .zeros_like (k_lens ))
656+ # vectorize operation
657+ # obtain the max length of all prefill reqs
658+ max_q = int (q_lens [valid ].max ().item ())
659+ max_k = int (k_lens [valid ].max ().item ())
660+ # generate local q and k indices
661+ q_indices = torch .arange (max_q , device = self .device )
662+ k_indices = torch .arange (max_k , device = self .device )
663+ # valid q and k indices of each reqs
664+ valid_q = valid [:, None ] & \
665+ (q_indices [None , :] < q_lens [:, None ])
666+ valid_k = valid [:, None ] & \
667+ (k_indices [None , :] < k_lens [:, None ])
668+ # where global q_indices >= global k_indices,
669+ # the mask is True
670+ # global q_indices = context_lens + local q_indices
671+ # global k_indices = local k_indices * dcp_size + dcp_rank
672+ # ====> local k_indices must be smaller than or equal to k_upper
673+ # k_upper=(context_lens + local q_indices - dcp_rank) // dcp_size
674+ k_upper = torch .div (
675+ context_lens [:, None ] + q_indices - dcp_rank ,
676+ dcp_size , rounding_mode = "floor" )
677+ k_upper = torch .where (
678+ valid_q ,
679+ torch .clamp (k_upper , min = - 1 ),
680+ k_upper .new_full (k_upper .shape , - 1 ))
681+ mask = (k_indices [None , None , :] <= k_upper [:, :, None ]) \
682+ & (k_upper [:, :, None ] >= 0 )
683+ valid_positions = valid_q [:, :, None ] & valid_k [:, None , :]
684+ # flashinfer backend needs flattened format
685+ custom_mask = torch .masked_select (mask , valid_positions )
650686 else :
651687 custom_mask = None
652688
0 commit comments