 from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
                                            get_flash_attn_version)
 from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonDecodeMetadata,
@@ -98,6 +99,11 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         # pre-allocated during capture.
         self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
+        # TODO(lucas): Until we add support for the DCP custom masking we need
+        # to restrict decodes to q_len == 1 when DCP is enabled.
+        self.__class__.reorder_batch_threshold = 1 \
+            if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
+
     def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                          max_seq_len, causal):
         if self.fa_aot_schedule:
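
A brief aside on the gating added in the hunk above: `reorder_batch_threshold` is pinned to 1 whenever decode context parallelism (DCP) spans more than one rank, so only q_len == 1 requests take the decode path until DCP custom masking is supported. A minimal standalone sketch of that conditional follows; the class name, the default value of 512, and the `dcp_world_size` argument are illustrative stand-ins, not vLLM internals.

```python
# Standalone illustration of the gating above; names and the default
# threshold of 512 are hypothetical, not vLLM internals.
class _Builder:
    reorder_batch_threshold: int = 512  # hypothetical class default


def _apply_dcp_gate(builder: _Builder, dcp_world_size: int) -> int:
    # Same expression as the diff: force q_len == 1 decodes under DCP,
    # otherwise keep whatever threshold the class already defines.
    builder.__class__.reorder_batch_threshold = (
        1 if dcp_world_size > 1 else builder.reorder_batch_threshold)
    return builder.reorder_batch_threshold


print(_apply_dcp_gate(_Builder(), dcp_world_size=1))  # 512, unchanged
print(_apply_dcp_gate(_Builder(), dcp_world_size=4))  # 1, restricted
```
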
@@ -172,6 +178,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
 
 
 class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
+    can_return_lse_for_decode: bool = True
 
     def __init__(
         self,
@@ -239,7 +246,7 @@ def _forward_decode(
         # to prevent invalid grid configuration during graph capture.
         max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
 
-        o = flash_attn_varlen_func(
+        attn_out = flash_attn_varlen_func(
             q=q_pe,
             k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
@@ -251,9 +258,16 @@ def _forward_decode(
             block_table=attn_metadata.decode.block_table,
             softmax_scale=self.scale,
             causal=True,
+            return_softmax_lse=self.need_to_return_lse_for_decode,
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
             num_splits=attn_metadata.decode.max_num_splits,
         )
 
-        return self._v_up_proj(o)
+        if self.need_to_return_lse_for_decode:
+            o, lse = attn_out
+            # FA returns LSE in shape [H, B] but DCP wants [B, H]
+            return o, lse.transpose(0, 1)  # [H, B] -> [B, H]
+        else:
+            o = attn_out
+            return o, None
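
To make the LSE handling at the end of the diff concrete: FlashAttention returns the log-sum-exp with heads first, `[num_heads, num_tokens]`, while the DCP reduction expects tokens first, hence the `transpose(0, 1)`. A small dummy-tensor sketch, with shapes and sizes that are illustrative only:

```python
# Dummy-tensor sketch of the LSE transpose; sizes are arbitrary and the
# tensors stand in for flash_attn_varlen_func's (out, lse) return value.
import torch

num_tokens, num_heads, head_dim = 8, 16, 128  # illustrative sizes
out = torch.randn(num_tokens, num_heads, head_dim)
lse = torch.randn(num_heads, num_tokens)  # FA layout: [H, B]

lse_for_dcp = lse.transpose(0, 1)  # [H, B] -> [B, H], as in the diff
assert lse_for_dcp.shape == (num_tokens, num_heads)
```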