Skip to content

Commit 5a107b7

Browse files
committed
vectorize mask construction
Signed-off-by: gaojc <1055866782@qq.com>
1 parent 1c93a20 commit 5a107b7

File tree

1 file changed

+63
-27
lines changed

1 file changed

+63
-27
lines changed

vllm/v1/attention/backends/flashinfer.py

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -619,34 +619,70 @@ def build(self,
619619

620620
if self.dcp_world_size > 1:
621621
# init custom mask for interleaved kv cache
622-
mask_arr = []
623-
q_lens = (qo_indptr_cpu[1:] -
624-
qo_indptr_cpu[:-1]).cpu().tolist()
622+
# |-------total_lens----------|
623+
# |--context_lens--|--q_lens--|
624+
# Example: dcp_size=2, dcp_rank=0
625+
# For a SINGLE prefill seq, q_lens=3, total_lens=5
626+
# k_lens on RANK1 is (5 - 1 - 0) // 2 + 1 = 3
627+
# mask.shape = [q_lens, k_lens] = [3,3]
628+
# mask [[True, True, False],
629+
# [True, True, False],
630+
# [True, True, True]]
631+
dcp_rank = self.dcp_rank
632+
dcp_size = self.dcp_world_size
633+
634+
q_lens = (qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]).to(
635+
dtype=torch.int64, device=self.device)
625636
total_lens = seq_lens_cpu[prefill_start:prefill_start +
626-
num_prefills].to(
627-
torch.int64).tolist()
628-
r = self.dcp_rank
629-
p = self.dcp_world_size
630-
for i in range(num_prefills):
631-
Q = int(q_lens[i])
632-
T = int(total_lens[i])
633-
if Q <= 0:
634-
mask_arr.append(torch.zeros(0, dtype=torch.bool))
635-
continue
636-
L = T - Q
637-
rightmost = L + Q - 1
638-
if rightmost < r:
639-
mask_arr.append(torch.zeros(0, dtype=torch.bool))
640-
continue
641-
K = ((rightmost - r) // p) + 1
642-
j = torch.arange(K)
643-
t = torch.arange(Q)
644-
upper = (L + t - r) // p
645-
upper = torch.clamp(upper, min=-1)
646-
mask_i = (j.unsqueeze(0) <= upper.unsqueeze(1)) & (
647-
upper.unsqueeze(1) >= 0)
648-
mask_arr.append(mask_i.flatten())
649-
custom_mask = torch.cat(mask_arr, dim=0).to(self.device)
637+
num_prefills].to(dtype=torch.int64,
638+
device=self.device)
639+
context_lens = total_lens - q_lens
640+
# max indices for global sequences
641+
max_indices = total_lens - 1
642+
# if max_indices are smaller than dcp_rank,
643+
# current rank has no kv cache, is invalid,
644+
# the mask is skipped
645+
valid = (max_indices >= dcp_rank)
646+
assert torch.any(valid), "There is no valid sequence"
647+
648+
# local kv lens on current dcp_rank
649+
k_lens = torch.div(max_indices - dcp_rank,
650+
dcp_size,
651+
rounding_mode="floor") + 1
652+
k_lens = torch.where(
653+
valid,
654+
k_lens,
655+
torch.zeros_like(k_lens))
656+
# vectorize operation
657+
# obtain the max length of all prefill reqs
658+
max_q = int(q_lens[valid].max().item())
659+
max_k = int(k_lens[valid].max().item())
660+
# generate local q and k indices
661+
q_indices = torch.arange(max_q, device=self.device)
662+
k_indices = torch.arange(max_k, device=self.device)
663+
# valid q and k indices of each req
664+
valid_q = valid[:, None] & \
665+
(q_indices[None, :] < q_lens[:, None])
666+
valid_k = valid[:, None] & \
667+
(k_indices[None, :] < k_lens[:, None])
668+
# where global q_indices >= global k_indices,
669+
# the mask is True
670+
# global q_indices = context_lens + local q_indices
671+
# global k_indices = local k_indices * dcp_size + dcp_rank
672+
# ====> local k_indices must be smaller than or equal to k_upper
673+
# k_upper=(context_lens + local q_indices - dcp_rank) // dcp_size
674+
k_upper = torch.div(
675+
context_lens[:, None] + q_indices - dcp_rank,
676+
dcp_size, rounding_mode="floor")
677+
k_upper = torch.where(
678+
valid_q,
679+
torch.clamp(k_upper, min=-1),
680+
k_upper.new_full(k_upper.shape, -1))
681+
mask = (k_indices[None, None, :] <= k_upper[:, :, None]) \
682+
& (k_upper[:, :, None] >= 0)
683+
valid_positions = valid_q[:, :, None] & valid_k[:, None, :]
684+
# flashinfer backend needs flattened format
685+
custom_mask = torch.masked_select(mask, valid_positions)
650686
else:
651687
custom_mask = None
652688

0 commit comments

Comments
 (0)