Commit dbaf50c

LucasWilkinson authored and xuebwang-amd committed
[BugFix] Fix MLA assert with CUTLASS MLA (vllm-project#25478)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent e996606 commit dbaf50c

File tree

1 file changed: +46 -18 lines

vllm/v1/attention/backends/mla/common.py

Lines changed: 46 additions & 18 deletions
@@ -204,7 +204,7 @@
 from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -436,6 +436,34 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     reorder_batch_threshold: ClassVar[int] = 1
 
+    @staticmethod
+    def determine_chunked_prefill_workspace_size(
+            vllm_config: VllmConfig) -> int:
+        scheduler_config = vllm_config.scheduler_config
+        cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
+
+        chunked_prefill_workspace_size = min(
+            # Try for 8 full-length requests or at least 4 pages per request
+            max(8 * model_config.max_model_len,
+                4 * scheduler_config.max_num_seqs * cache_config.block_size),
+            # For long-context models, avoid over-allocating (which would
+            # limit kv-cache space) by capping the workspace at 64k tokens,
+            # which would result in the workspace being:
+            #   2*(576)*(64*1024) = 144mb
+            #   (assuming 576 MLA head dim, and fp16)
+            # and the up-projected context being:
+            #   2*(192*128)*(64*1024) = 3gb
+            #   (assuming 192 QK head dim, 128 heads, and fp16)
+            64 * 1024)
+
+        # Enforce that we have enough for at least 1 page per request
+        chunked_prefill_workspace_size = max(
+            chunked_prefill_workspace_size,
+            scheduler_config.max_num_seqs * cache_config.block_size)
+
+        return chunked_prefill_workspace_size
+
     def __init__(self,
                  kv_cache_spec: AttentionSpec,
                  layer_names: list[str],
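
To make the new helper's behaviour concrete, here is a small standalone sketch of the same min/max sizing logic. The config values below (max_model_len, max_num_seqs, block_size) are purely illustrative assumptions, not taken from any real model config, and the function is a simplified stand-in rather than vLLM's actual implementation.

# Standalone sketch of the sizing logic in determine_chunked_prefill_workspace_size
# (illustrative values only; the real helper reads them from VllmConfig).
def workspace_size(max_model_len: int, max_num_seqs: int,
                   block_size: int) -> int:
    size = min(
        # room for 8 full-length requests, or at least 4 pages per request
        max(8 * max_model_len, 4 * max_num_seqs * block_size),
        # capped at 64k tokens so long-context models do not eat kv-cache space
        64 * 1024)
    # the clamp that replaces the old assert: guarantee at least 1 page/request
    return max(size, max_num_seqs * block_size)

# Long-context model: 8 * 163840 is huge, so the 64k cap wins.
print(workspace_size(max_model_len=163_840, max_num_seqs=256, block_size=16))
# -> 65536

# Many sequences with a large block size (e.g. the 128-token blocks used by
# CUTLASS MLA): 1024 * 128 = 131072 > 64k. The old assert would have failed
# here; the helper now clamps the workspace up to 131072 instead.
print(workspace_size(max_model_len=4096, max_num_seqs=1024, block_size=128))
# -> 131072

The second case is presumably the scenario behind the commit title: with a large enough max_num_seqs * block_size product the removed assert could fire, whereas the shared helper simply grows the workspace to fit.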
@@ -448,7 +476,6 @@ def __init__(self,
         scheduler_config = vllm_config.scheduler_config
         self.model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
-        cache_config = vllm_config.cache_config
         self.compilation_config = vllm_config.compilation_config
         self.device = device
 
@@ -468,22 +495,9 @@ def __init__(self,
         if self.aot_schedule:
             self.page_size = self.kv_cache_spec.block_size
 
-        self.chunked_prefill_workspace_size = min(
-            # Max sure there is enough for 8 full length request or at least
-            # 4 pages of cache per request
-            max(8 * self.model_config.max_model_len,
-                4 * scheduler_config.max_num_seqs * cache_config.block_size),
-            # For long-context models try not to over-allocate limiting
-            # kv-cache space, limiting it to 64k tokens,
-            # which would result in the workspace being:
-            #   2*(576)*(64*1024) = 144mb
-            #   (assuming 576 MLA head dim, and fp16)
-            # which would result in up-projected context being
-            #   2*(192*128)*(64*1024) = 3gb
-            #   (assuming 192 QK head dim, 128 heads, and fp16)
-            64 * 1024)
-        assert self.chunked_prefill_workspace_size >= \
-            scheduler_config.max_num_seqs * cache_config.block_size
+        self.chunked_prefill_workspace_size = \
+            self.determine_chunked_prefill_workspace_size(vllm_config)
+
         if self.dcp_world_size > 1:
             # Note(hc): The local kvcache is incomplete when DCP is triggered,
             # an additional kvcache allgather across the DCP group is therefore
@@ -999,6 +1013,10 @@ def __init__(
 
         self.dcp_world_size: Optional[int] = None
 
+        self.chunked_prefill_workspace_size = \
+            MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
+                get_current_vllm_config())
+
     def _flash_attn_varlen_diff_headdims(self,
                                          q,
                                          k,
@@ -1513,6 +1531,16 @@ def forward(
                 " for MLACommonImpl")
 
         if attn_metadata is None:
+            # During the profile run, try to simulate the worst-case output
+            # size for `self.kv_b_proj(kv_c_normed)` in
+            # `_compute_prefill_context`, since this can be large
+            _ = torch.empty(
+                (self.chunked_prefill_workspace_size, self.num_heads,
+                 self.qk_nope_head_dim + self.v_head_dim),
+                device=k_c_normed.device,
+                dtype=k_c_normed.dtype,
+            )
+
             # The zero fill is required when used with DP + EP
             # to ensure all ranks within a DP group compute the
             # same expert outputs.
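
For a sense of why the profile run needs to reserve this buffer, here is a rough sizing sketch of the dummy tensor allocated above. The head count and head dims are illustrative DeepSeek-style assumptions rather than values read from any config, and the meta device is used so nothing is actually allocated; the real code allocates on the device precisely so the memory profiler observes the peak.

import torch

# Illustrative (assumed) MLA shapes: 128 heads, qk_nope_head_dim = 128,
# v_head_dim = 128, fp16, with the 64k-token workspace cap from above.
workspace_tokens = 64 * 1024
num_heads = 128
qk_nope_head_dim = 128
v_head_dim = 128

t = torch.empty(
    (workspace_tokens, num_heads, qk_nope_head_dim + v_head_dim),
    dtype=torch.float16,
    device="meta",  # shape/dtype bookkeeping only; no memory is touched
)
print(t.nelement() * t.element_size())  # 4294967296 bytes, i.e. 4 GiB

Under these assumptions the worst-case output of kv_b_proj(kv_c_normed) is on the order of gigabytes, which is why the forward pass now materializes a placeholder of that shape during profiling rather than leaving the peak to be discovered only later in `_compute_prefill_context`.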
