Skip to content

Commit cd3bafa

Browse files
accuracy fix
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 151e69b commit cd3bafa

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

vllm/v1/attention/backends/mla/flashattn_mla.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,11 @@ def _forward_decode(
257257
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
258258
num_splits=attn_metadata.decode.max_num_splits,
259259
)
260-
260+
261261
if self.need_to_return_lse_for_decode:
262262
o, lse = attn_out
263-
return o, lse
263+
# FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
264+
return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ]
264265
else:
265266
o = attn_out
266267
return o, None

0 commit comments

Comments
 (0)