
Commit 8e41390

simon-mo authored and lulmer committed
[Perf] Improve MLA on V1 (vllm-project#14540)
Signed-off-by: simon-mo <simon.mo@hey.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent 4416ea1 commit 8e41390

File tree

1 file changed, +41 -27 lines


vllm/v1/attention/backends/mla/common.py

Lines changed: 41 additions & 27 deletions
@@ -223,6 +223,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 
 try:
@@ -471,43 +472,46 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int) -> M:
         assert self._num_decodes + self._num_prefills == num_reqs
 
+        # Note(simon): be careful about the CPU <> GPU memory movement in this
+        # function. We should avoid GPU -> CPU sync as much as possible because
+        # it blocks on all previous kernels.
         device = self.runner.device
-        query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
-            device, non_blocking=True)
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device,
-                                                          non_blocking=True)
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
+            device, non_blocking=True)
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
 
+        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
+        seq_lens = seq_lens_cpu.to(device, non_blocking=True)
+        max_query_len = seq_lens_cpu.max().item()
+
         prefill_metadata = None
         if self._num_prefills > 0:
             reqs_start = self._num_decodes  # prefill_start
             tokens_start = self._num_decode_tokens
 
             context_lens_cpu = self.runner.input_batch.\
                 num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
-            context_lens = context_lens_cpu.to(device, non_blocking=True)
+            max_context_len_cpu = context_lens_cpu.max().item()
+            num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
 
             chunked_context_metadata = None
             if self.chunked_prefill_enabled and self._num_prefills > 0 \
-                    and context_lens.max() > 0:
+                    and max_context_len_cpu > 0:
                 # NOTE: it is recommend you read the `Chunked Prefill` section
                 # in the comment at the top of the file before trying to
                 # understand the following code
 
-                num_prefills_with_context = (context_lens > 0).sum().item()
-
                 # currently we allocate an equal amount of workspace for each
                 # prefill in the batch, we could probably use a more advanced
                 # algorithm here and allocate more workspace to prefills with
                 # longer context lengths
-                max_context_chunk = \
-                    self.chunked_prefill_workspace_size \
-                        // num_prefills_with_context
+                max_context_chunk = (self.chunked_prefill_workspace_size //
+                                     num_prefills_with_context_cpu)
 
                 # align max_context_chunk to page_size by rounding down,
                 # currently the `gather_cache` kernel cannot handle
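
To make the intent of this hunk concrete, here is a minimal, self-contained sketch of the sync-avoidance pattern it applies (the function name and arguments are illustrative, not part of vLLM):

# Illustrative sketch (not vLLM code): reduce on the CPU copy, ship to GPU async.
import torch

def build_lengths(seq_lens_cpu: torch.Tensor, device: torch.device):
    # .max().item() on a CPU tensor needs no device synchronization; the same
    # call on a CUDA tensor would block until every queued kernel has finished.
    max_query_len = seq_lens_cpu.max().item()
    # The host -> device copy is issued with non_blocking=True so it can
    # overlap with compute (fully asynchronous only from pinned host memory).
    seq_lens = seq_lens_cpu.to(device, non_blocking=True)
    return seq_lens, max_query_len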
@@ -516,30 +520,35 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
                                                self.page_size)
 
                 assert max_context_chunk > 0
-                num_chunks = cdiv(context_lens.max(), max_context_chunk)
+                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
 
                 # if `max_context_chunk = 256`, `num_chunks = 3`, and
                 # `num_prefills_with_context = 4`, create a tensor that looks
                 # like
                 # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
+                # Note(simon): this is done in CPU because of downstream's
+                # of `to_list`.
                 chunk_starts = \
-                    torch.arange(num_chunks, device=device, dtype=torch.int32) \
+                    torch.arange(num_chunks, dtype=torch.int32) \
                     .unsqueeze(1).expand(-1, self._num_prefills) \
                     * max_context_chunk
-                chunk_ends = torch.min(context_lens.unsqueeze(0),
+                chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                        chunk_starts + max_context_chunk)
                 chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
-                _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
-                    torch.int32)
-                zero = torch.zeros(num_chunks,
-                                   dtype=torch.int32,
-                                   device=device).unsqueeze(-1)
+
+                cu_seq_lens_cpu = torch.zeros(num_chunks,
+                                              self._num_prefills + 1,
+                                              dtype=torch.int32,
+                                              pin_memory=True)
+                torch.cumsum(chunk_seq_lens,
+                             dim=1,
+                             out=cu_seq_lens_cpu[:, 1:],
+                             dtype=torch.int32)
 
                 chunked_context_metadata = \
                     MLACommonPrefillMetadata.ChunkedContextMetadata(
-                        cu_seq_lens=torch.cat(
-                            [zero, _chunk_cu_seq_lens], dim=1),
-                        starts=chunk_starts,
+                        cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
+                        starts=chunk_starts.to(device, non_blocking=True),
                         seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                         max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                         workspace=self.chunked_prefill_workspace,
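
This hunk replaces the old `zero` + `torch.cat` construction with a single pinned CPU buffer that is filled in place and copied asynchronously. A standalone sketch of that pattern, assuming `chunk_seq_lens` is an integer tensor of shape (num_chunks, num_prefills) as in the comment above (the helper name is illustrative):

import torch

def build_cu_seq_lens(chunk_seq_lens: torch.Tensor, device: torch.device):
    num_chunks, num_prefills = chunk_seq_lens.shape
    # Zero-initialised so column 0 stays 0; pinned (page-locked) memory lets
    # the later non_blocking copy overlap with GPU work (needs a CUDA runtime).
    cu_seq_lens_cpu = torch.zeros(num_chunks,
                                  num_prefills + 1,
                                  dtype=torch.int32,
                                  pin_memory=True)
    # Write the running sums directly into columns 1..num_prefills, avoiding
    # the extra zeros tensor and torch.cat of the previous code path.
    torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:],
                 dtype=torch.int32)
    return cu_seq_lens_cpu.to(device, non_blocking=True)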
@@ -553,7 +562,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
                 block_table=block_table[reqs_start:, ...],
                 query_start_loc=query_start_loc[reqs_start:] -
                 query_start_loc[reqs_start],
-                max_query_len=seq_lens[reqs_start:].max().item(),
+                max_query_len=max_query_len,
                 chunked_context=chunked_context_metadata,
             )
 
@@ -629,7 +638,9 @@ def __init__(
         # already inside an attention custom op), pull out the forward
         # method from the rotary embedding and call it directly
         # TODO(lucas): we should probably find a cleaner way to do this
-        self.rotary_emb = rotary_emb._forward_method
+        self.rotary_emb = rotary_emb.forward_native
+        if current_platform.is_cuda():
+            self.rotary_emb = rotary_emb.forward_cuda
 
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
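
For reference, the dispatch introduced here boils down to the following pattern (distilled from the hunk above; `pick_rope_forward` is an illustrative helper, not vLLM API):

from vllm.platforms import current_platform

def pick_rope_forward(rotary_emb):
    # Default to the portable PyTorch implementation and switch to the CUDA
    # kernel only when running on a CUDA platform, instead of relying on the
    # private _forward_method attribute.
    rope_fn = rotary_emb.forward_native
    if current_platform.is_cuda():
        rope_fn = rotary_emb.forward_cuda
    return rope_fn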
@@ -1043,17 +1054,20 @@ def forward(
             decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
+
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
+                attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
+                decode_k_pe)
 
         if has_prefill:
             assert attn_metadata.prefill is not None
             prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+
             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                attn_metadata.prefill.input_positions, prefill_q_pe,
-                prefill_k_pe)
+                attn_metadata.prefill.input_positions,
+                prefill_q_pe.contiguous(), prefill_k_pe)
 
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
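
The added `.contiguous()` calls ensure the rotary-embedding kernel receives densely laid-out tensors; in particular `prefill_q_pe` is a slice along the last head dimension, which is a strided view. A tiny standalone illustration (shapes are arbitrary):

import torch

q = torch.randn(8, 16, 192)        # (tokens, heads, qk_head_dim)
q_pe = q[..., 128:]                # rope part, e.g. qk_nope_head_dim = 128
print(q_pe.is_contiguous())        # False: still a strided view of q
print(q_pe.contiguous().is_contiguous())  # True: dense copy for the kernel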
