Skip to content

Commit 151e69b

Browse files
fa MLA cp support
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent ed16d0f commit 151e69b

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

vllm/v1/attention/backends/mla/flashattn_mla.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
172172

173173

174174
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
175+
can_return_lse_for_decode: bool = True
175176

176177
def __init__(
177178
self,
@@ -239,7 +240,7 @@ def _forward_decode(
239240
# to prevent invalid grid configuration during graph capture.
240241
max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
241242

242-
o = flash_attn_varlen_func(
243+
attn_out = flash_attn_varlen_func(
243244
q=q_pe,
244245
k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
245246
v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
@@ -251,9 +252,15 @@ def _forward_decode(
251252
block_table=attn_metadata.decode.block_table,
252253
softmax_scale=self.scale,
253254
causal=True,
255+
return_softmax_lse=self.need_to_return_lse_for_decode,
254256
fa_version=3, # only version 3 is supported
255257
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
256258
num_splits=attn_metadata.decode.max_num_splits,
257259
)
258-
259-
return self._v_up_proj(o)
260+
261+
if self.need_to_return_lse_for_decode:
262+
o, lse = attn_out
263+
return o, lse
264+
else:
265+
o = attn_out
266+
return o, None

vllm/v1/worker/gpu_model_runner.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,9 +440,6 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
440440
return
441441

442442
if self.reorder_batch_threshold is not None:
443-
if self.dcp_world_size > 1:
444-
assert self.reorder_batch_threshold == 1, \
445-
"DCP not support reorder_batch_threshold > 1 now."
446443
reorder_batch_to_split_decodes_and_prefills(
447444
self.input_batch,
448445
scheduler_output,

0 commit comments

Comments
 (0)