@@ -1211,13 +1211,18 @@ def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q,
12111211 k , v , return_softmax_lse ):
12121212 assert isinstance (prefill , FlashInferPrefillMetadata )
12131213 assert prefill .prefill_main is not None
1214- return prefill .prefill_main .run (
1214+ ret = prefill .prefill_main .run (
12151215 q = q ,
12161216 k = k ,
12171217 v = v ,
12181218 return_lse = return_softmax_lse ,
12191219 )
12201220
1221+ if isinstance (ret , tuple ):
1222+ # Convert from (q_len, num_heads) to (num_heads, q_len)
1223+ return ret [0 ], ret [1 ].transpose (0 , 1 ).contiguous ()
1224+ return ret
1225+
12211226 def _run_prefill_new_tokens_cudnn (self , prefill : MLACommonPrefillMetadata ,
12221227 q , k , v , return_softmax_lse ):
12231228 assert isinstance (prefill , CudnnPrefillMetadata )
@@ -1260,12 +1265,14 @@ def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
12601265 def _run_prefill_context_chunk_fi (self , prefill : MLACommonPrefillMetadata ,
12611266 chunk_idx : int , q , k , v ):
12621267 assert isinstance (prefill , FlashInferPrefillMetadata )
1263- return prefill .prefill_chunks [chunk_idx ].run (
1268+ attn_out , lse = prefill .prefill_chunks [chunk_idx ].run (
12641269 q = q ,
12651270 k = k ,
12661271 v = v ,
12671272 return_lse = True ,
12681273 )
1274+ # Convert from (q_len, num_heads) to (num_heads, q_len)
1275+ return attn_out , lse .transpose (0 , 1 ).contiguous ()
12691276
12701277 def _run_prefill_context_chunk_cudnn (self ,
12711278 prefill : MLACommonPrefillMetadata ,
0 commit comments