1717 get_flash_attn_version ,
1818)
1919from vllm .config import VllmConfig
20- from vllm .distributed .parallel_state import get_dcp_group
2120from vllm .logger import init_logger
2221from vllm .v1 .attention .backends .mla .common import (
2322 MLACommonBackend ,
@@ -107,12 +106,6 @@ def __init__(
107106 # pre-allocated during capture.
108107 self .max_num_splits = envs .VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
109108
110- # TODO(lucas): Until we add support for the DCP custom masking we need
111- # to restrict decodes to q_len == 1 when DCP is enabled.
112- self .reorder_batch_threshold = (
113- 1 if get_dcp_group ().world_size > 1 else self .reorder_batch_threshold
114- )
115-
116109 def _schedule_decode (
117110 self , num_reqs , cu_query_lens , max_query_len , seqlens , max_seq_len , causal
118111 ):
@@ -121,7 +114,7 @@ def _schedule_decode(
121114 batch_size = num_reqs ,
122115 max_seqlen_q = max_query_len ,
123116 max_seqlen_k = max_seq_len ,
124- num_heads_q = self .num_heads ,
117+ num_heads_q = self .num_heads * self . dcp_world_size ,
125118 num_heads_kv = 1 ,
126119 headdim = self .mla_dims .qk_rope_head_dim ,
127120 cache_seqlens = seqlens ,
@@ -142,10 +135,11 @@ def _build_decode(
142135 query_start_loc_cpu : torch .Tensor ,
143136 query_start_loc_device : torch .Tensor ,
144137 num_decode_tokens : int ,
138+ dcp_tot_seq_lens_device : Optional [torch .Tensor ],
145139 ) -> FlashAttnMLADecodeMetadata :
146140 query_lens_cpu = query_start_loc_cpu [1 :] - query_start_loc_cpu [:- 1 ]
147141 max_query_len = query_lens_cpu .max ().item ()
148- max_seq_len = seq_lens_cpu .max ().item ()
142+ max_seq_len = seq_lens_device .max ().item ()
149143
150144 scheduler_metadata = self ._schedule_decode (
151145 num_reqs = seq_lens_cpu .numel (),
@@ -188,6 +182,7 @@ def _build_decode(
188182 max_seq_len = max_seq_len ,
189183 scheduler_metadata = scheduler_metadata ,
190184 max_num_splits = max_num_splits ,
185+ dcp_tot_seq_lens = dcp_tot_seq_lens_device ,
191186 )
192187
193188
@@ -289,6 +284,9 @@ def _forward_decode(
289284 fa_version = 3 , # only version 3 is supported
290285 scheduler_metadata = attn_metadata .decode .scheduler_metadata ,
291286 num_splits = attn_metadata .decode .max_num_splits ,
287+ cp_world_size = self .dcp_world_size ,
288+ cp_rank = self .dcp_rank ,
289+ cp_tot_seqused_k = attn_metadata .decode .dcp_tot_seq_lens ,
292290 )
293291
294292 if self .need_to_return_lse_for_decode :
0 commit comments