 from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -436,6 +436,34 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     reorder_batch_threshold: ClassVar[int] = 1
 
+    @staticmethod
+    def determine_chunked_prefill_workspace_size(
+            vllm_config: VllmConfig) -> int:
+        scheduler_config = vllm_config.scheduler_config
+        cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
+
+        chunked_prefill_workspace_size = min(
+            # Try for 8 full-length requests or at least 4 pages per request
+            max(8 * model_config.max_model_len,
+                4 * scheduler_config.max_num_seqs * cache_config.block_size),
+            # For long-context models, avoid over-allocating (which would
+            # limit kv-cache space) by capping the workspace at 64k tokens,
+            # which would result in the workspace being:
+            #   2*(576)*(64*1024) = 144mb
+            # (assuming 576 MLA head dim, and fp16)
+            # and the up-projected context being:
+            #   2*(192*128)*(64*1024) = 3gb
+            # (assuming 192 QK head dim, 128 heads, and fp16)
+            64 * 1024)
+
+        # Enforce that we have enough for at least 1 page per request
+        chunked_prefill_workspace_size = max(
+            chunked_prefill_workspace_size,
+            scheduler_config.max_num_seqs * cache_config.block_size)
+
+        return chunked_prefill_workspace_size
+
     def __init__(self,
                  kv_cache_spec: AttentionSpec,
                  layer_names: list[str],
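To make the sizing logic concrete, here is a small standalone sketch of the same min/max arithmetic. The config values (`max_model_len=32768`, `max_num_seqs=256`, `block_size=16`) are hypothetical numbers chosen for illustration, not vLLM defaults:

```python
# Standalone sketch of the workspace sizing above, with hypothetical values.
max_model_len = 32 * 1024   # hypothetical model context length
max_num_seqs = 256          # hypothetical scheduler limit
block_size = 16             # hypothetical kv-cache page size

workspace = min(
    # 8 full-length requests, or 4 pages for every schedulable request
    max(8 * max_model_len, 4 * max_num_seqs * block_size),
    # cap at 64k tokens for long-context models
    64 * 1024,
)
# never less than 1 page per request
workspace = max(workspace, max_num_seqs * block_size)

print(workspace)  # min(max(262144, 16384), 65536) -> 65536 tokens
```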
@@ -448,7 +476,6 @@ def __init__(self,
         scheduler_config = vllm_config.scheduler_config
         self.model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
-        cache_config = vllm_config.cache_config
         self.compilation_config = vllm_config.compilation_config
         self.device = device
 
@@ -468,22 +495,9 @@ def __init__(self,
         if self.aot_schedule:
             self.page_size = self.kv_cache_spec.block_size
 
-        self.chunked_prefill_workspace_size = min(
-            # Max sure there is enough for 8 full length request or at least
-            # 4 pages of cache per request
-            max(8 * self.model_config.max_model_len,
-                4 * scheduler_config.max_num_seqs * cache_config.block_size),
-            # For long-context models try not to over-allocate limiting
-            # kv-cache space, limiting it to 64k tokens,
-            # which would result in the workspace being:
-            #   2*(576)*(64*1024) = 144mb
-            # (assuming 576 MLA head dim, and fp16)
-            # which would result in up-projected context being
-            #   2*(192*128)*(64*1024) = 3gb
-            # (assuming 192 QK head dim, 128 heads, and fp16)
-            64 * 1024)
-        assert self.chunked_prefill_workspace_size >= \
-            scheduler_config.max_num_seqs * cache_config.block_size
+        self.chunked_prefill_workspace_size = \
+            self.determine_chunked_prefill_workspace_size(vllm_config)
+
         if self.dcp_world_size > 1:
             # Note(hc): The local kvcache is incomplete when DCP is triggered,
             # an additional kvcache allgather across the DCP group is therefore
@@ -999,6 +1013,10 @@ def __init__(
 
         self.dcp_world_size: Optional[int] = None
 
+        self.chunked_prefill_workspace_size = \
+            MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
+            get_current_vllm_config())
+
     def _flash_attn_varlen_diff_headdims(self,
                                          q,
                                          k,
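Since the impl side reaches for the ambient config via `get_current_vllm_config()` rather than a constructor argument, the metadata builder and the impl should end up with the same workspace size. A short illustrative check (not part of the diff; the assertion is only a sketch of the intent):

```python
# Illustrative only: both sides derive the size from the same shared helper,
# so they cannot drift apart as the sizing heuristic evolves.
vllm_config = get_current_vllm_config()
builder_size = MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(vllm_config)
impl_size = MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(vllm_config)
assert builder_size == impl_size
```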
@@ -1513,6 +1531,16 @@ def forward(
                 " for MLACommonImpl")
 
         if attn_metadata is None:
+            # During the profile run, try to simulate the worst-case output
+            # size for `self.kv_b_proj(kv_c_normed)` in
+            # `_compute_prefill_context`, since this can be large
+            _ = torch.empty(
+                (self.chunked_prefill_workspace_size, self.num_heads,
+                 self.qk_nope_head_dim + self.v_head_dim),
+                device=k_c_normed.device,
+                dtype=k_c_normed.dtype,
+            )
+
             # The zero fill is required when used with DP + EP
             # to ensure all ranks within a DP group compute the
             # same expert outputs.
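The dummy allocation above only affects how much memory the profile run reserves. A rough back-of-the-envelope sketch of its size, assuming DeepSeek-style MLA dimensions (128 heads, qk_nope_head_dim=128, v_head_dim=128, 2-byte fp16/bf16) and the 64k-token workspace cap; these numbers are assumptions for illustration, not measurements:

```python
# Hypothetical size of the simulated kv_b_proj output during profiling,
# assuming DeepSeek-style MLA dims; not measured from a real run.
workspace_tokens = 64 * 1024             # chunked_prefill_workspace_size cap
num_heads = 128                          # assumed
qk_nope_head_dim, v_head_dim = 128, 128  # assumed
bytes_per_elem = 2                       # fp16 / bf16

size = workspace_tokens * num_heads * (qk_nope_head_dim + v_head_dim) * bytes_per_elem
print(size / 2**30)  # 4.0 GiB reserved so the real prefill-context up-projection fits
```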