@@ -87,8 +87,8 @@ def flash_mla():
 
 
 @torch.inference_mode()
-def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
-                    h_q, h_kv, d, dv, causal, dtype):
+def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
+                   h_q, h_kv, d, dv, causal, dtype):
     # pip install flashinfer-python
     import flashinfer
     assert d > dv, "mla with rope dim should be larger than no rope dim"
@@ -128,7 +128,7 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
         blocked_k.dtype,
     )
 
-    def flash_infer():
+    def flashinfer():
         output, lse = mla_wrapper.run(
             q_nope.view(-1, h_q, dv),
             q_pe.view(-1, h_q, d - dv),
@@ -137,8 +137,8 @@ def flash_infer():
             return_lse=True)
         return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)
 
-    out_flash, lse_flash = flash_infer()
-    t = triton.testing.do_bench(flash_infer)
+    out_flash, lse_flash = flashinfer()
+    t = triton.testing.do_bench(flashinfer)
     return out_flash, lse_flash, t
 
 
@@ -459,7 +459,7 @@ def flash_mla_tilelang():
     "torch": run_torch_mla,
     "tilelang": run_flash_mla_tilelang,
     "flash_mla": run_flash_mla,
-    "flash_infer": run_flash_infer,
+    "flashinfer": run_flashinfer,
     "flash_mla_triton": run_flash_mla_triton,
 }
 
@@ -496,9 +496,9 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
                                          s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
 
     torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
-    if target not in ["flash_infer", "flash_mla_triton", "tilelang"
-                      ] and baseline not in ["flash_infer", "flash_mla_triton", "tilelang"]:
-        # flash_infer has a different lse return value
+    if target not in ["flashinfer", "flash_mla_triton", "tilelang"
+                      ] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]:
+        # flashinfer has a different lse return value
         # flash_mla_triton and flash_mla_tilelang doesn't return lse
         torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"
 
@@ -554,7 +554,7 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
     "torch",
     "tilelang",
     "flash_mla",
-    "flash_infer",
+    "flashinfer",
     "flash_mla_triton",
 ]
 
0 commit comments