@@ -172,6 +172,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
 
 
 class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
+    can_return_lse_for_decode: bool = True
 
     def __init__(
         self,
@@ -239,7 +240,7 @@ def _forward_decode(
         # to prevent invalid grid configuration during graph capture.
         max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
 
-        o = flash_attn_varlen_func(
+        attn_out = flash_attn_varlen_func(
             q=q_pe,
             k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
@@ -251,9 +252,15 @@ def _forward_decode(
             block_table=attn_metadata.decode.block_table,
             softmax_scale=self.scale,
             causal=True,
+            return_softmax_lse=self.need_to_return_lse_for_decode,
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
             num_splits=attn_metadata.decode.max_num_splits,
         )
-
-        return self._v_up_proj(o)
+
+        if self.need_to_return_lse_for_decode:
+            o, lse = attn_out
+            return o, lse
+        else:
+            o = attn_out
+            return o, None