 from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -436,6 +436,34 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     reorder_batch_threshold: ClassVar[int] = 1
 
+    @staticmethod
+    def determine_chunked_prefill_workspace_size(
+            vllm_config: VllmConfig) -> int:
+        scheduler_config = vllm_config.scheduler_config
+        cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
+
+        chunked_prefill_workspace_size = min(
+            # Try for 8 full-length requests or at least 4 pages per request
+            max(8 * model_config.max_model_len,
+                4 * scheduler_config.max_num_seqs * cache_config.block_size),
+            # For long-context models, try not to over-allocate (limiting
+            # kv-cache space) by capping the workspace at 64k tokens,
+            # which would make the workspace:
+            #   2*(576)*(64*1024) = 72mb
+            # (assuming 576 MLA head dim, and fp16)
+            # and the up-projected context:
+            #   2*(192*128)*(64*1024) = 3gb
+            # (assuming 192 QK head dim, 128 heads, and fp16)
+            64 * 1024)
+
+        # Enforce that there is enough for at least 1 page per request
+        chunked_prefill_workspace_size = max(
+            chunked_prefill_workspace_size,
+            scheduler_config.max_num_seqs * cache_config.block_size)
+
+        return chunked_prefill_workspace_size
+
     def __init__(self,
                  kv_cache_spec: AttentionSpec,
                  layer_names: list[str],
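For intuition on the clamp logic in the new helper, here is a minimal sketch of the arithmetic. The config values (`max_model_len=4096`, `max_num_seqs=256`, `block_size=16`) are illustrative assumptions, not values from this PR:

```python
# Minimal sketch of determine_chunked_prefill_workspace_size's sizing logic.
# The config values below are illustrative assumptions, not from this PR.
max_model_len, max_num_seqs, block_size = 4096, 256, 16

workspace = min(
    max(8 * max_model_len,               # 32768: room for 8 full-length requests
        4 * max_num_seqs * block_size),  # 16384: 4 pages per request
    64 * 1024,                           # 65536-token cap for long-context models
)
# Floor: at least 1 page per request
workspace = max(workspace, max_num_seqs * block_size)  # 4096-token floor
assert workspace == 32768
```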
@@ -448,7 +476,6 @@ def __init__(self,
         scheduler_config = vllm_config.scheduler_config
         self.model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
-        cache_config = vllm_config.cache_config
         self.compilation_config = vllm_config.compilation_config
         self.device = device
@@ -468,22 +495,9 @@ def __init__(self,
         if self.aot_schedule:
             self.page_size = self.kv_cache_spec.block_size
 
-        self.chunked_prefill_workspace_size = min(
-            # Max sure there is enough for 8 full length request or at least
-            # 4 pages of cache per request
-            max(8 * self.model_config.max_model_len,
-                4 * scheduler_config.max_num_seqs * cache_config.block_size),
-            # For long-context models try not to over-allocate limiting
-            # kv-cache space, limiting it to 64k tokens,
-            # which would result in the workspace being:
-            #  2*(576)*(64*1024) = 144mb
-            # (assuming 576 MLA head dim, and fp16)
-            #  which would result in up-projected context being
-            #  2*(192*128)*(64*1024) = 3gb
-            # (assuming 192 QK head dim, 128 heads, and fp16)
-            64 * 1024)
-        assert self.chunked_prefill_workspace_size >= \
-            scheduler_config.max_num_seqs * cache_config.block_size
+        self.chunked_prefill_workspace_size = \
+            self.determine_chunked_prefill_workspace_size(vllm_config)
+
         if self.dcp_world_size > 1:
             # Note(hc): The local kvcache is incomplete when DCP is triggered,
             # an additional kvcache allgather across the DCP group is therefore
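One behavioral note on this refactor: the old inline code asserted the one-page-per-request floor, while the new helper enforces it with `max()`. A sketch with hypothetical numbers chosen so the floor binds (the old assert would have fired here):

```python
# Hypothetical config where the 64k cap clamps below the floor:
# many sequences, short context. Illustrative values only.
max_model_len, max_num_seqs, block_size = 512, 8192, 16

size = min(
    max(8 * max_model_len, 4 * max_num_seqs * block_size),  # 524288
    64 * 1024,                                              # clamps to 65536
)
floor = max_num_seqs * block_size  # 131072

# Old behavior: `assert size >= floor` would raise (65536 < 131072).
# New behavior: bump the workspace up to the floor instead.
size = max(size, floor)
assert size == 131072
```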
@@ -999,6 +1013,10 @@ def __init__(
 
         self.dcp_world_size: Optional[int] = None
 
+        self.chunked_prefill_workspace_size = \
+            MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
+                get_current_vllm_config())
+
     def _flash_attn_varlen_diff_headdims(self,
                                          q,
                                          k,
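The impl-side `__init__` recovers the config from vLLM's ambient config context instead of taking it as a constructor argument. A minimal sketch of that mechanism, assuming construction happens under `set_current_vllm_config` (the context manager in `vllm.config` that `get_current_vllm_config` reads from):

```python
# Minimal sketch, assuming the attention impl is constructed while a config
# context is active, as vLLM does when building the model.
from vllm.config import get_current_vllm_config, set_current_vllm_config

def build_impl(vllm_config):
    with set_current_vllm_config(vllm_config):
        # Inside the context, code can recover the config without it
        # being threaded through every constructor.
        assert get_current_vllm_config() is vllm_config
```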
@@ -1513,6 +1531,16 @@ def forward(
                 " for MLACommonImpl")
 
         if attn_metadata is None:
+            # During the profile run, try to simulate the worst-case output
+            # size for `self.kv_b_proj(kv_c_normed)` in
+            # `_compute_prefill_context`, since this can be large
+            _ = torch.empty(
+                (self.chunked_prefill_workspace_size, self.num_heads,
+                 self.qk_nope_head_dim + self.v_head_dim),
+                device=k_c_normed.device,
+                dtype=k_c_normed.dtype,
+            )
+
             # The zero fill is required when used with DP + EP
             # to ensure all ranks within a DP group compute the
             # same expert outputs.
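For scale, a rough estimate of the dummy allocation added above, under assumed DeepSeek-style dimensions (`num_heads=128`, `qk_nope_head_dim=128`, `v_head_dim=128`, fp16; these are illustrative assumptions, not values read from the PR):

```python
# Back-of-the-envelope size of the profile-run dummy tensor, which mirrors
# the worst-case kv_b_proj output shape. All dimensions are assumptions.
workspace_tokens = 64 * 1024              # chunked_prefill_workspace_size cap
num_heads = 128
qk_nope_head_dim, v_head_dim = 128, 128
bytes_per_elem = 2                        # fp16

size = workspace_tokens * num_heads * (qk_nope_head_dim + v_head_dim) * bytes_per_elem
print(f"{size / 2**30:.1f} GiB")          # -> 4.0 GiB
```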