 from vllm.triton_utils import triton
 
 
-def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
+def cal_diff(x: torch.Tensor,
+             y: torch.Tensor,
+             name: str,
+             use_fp8: bool = False) -> None:
     x, y = x.double(), y.double()
     cos_diff = 1 - 2 * (x * y).sum().item() / max(
         (x * x + y * y).sum().item(), 1e-12)
-    assert cos_diff < 1e-5
+    if (use_fp8):
+        assert cos_diff < 1e-4
+    else:
+        assert cos_diff < 1e-5
 
 FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
     if not is_flashmla_supported()[0] else "FlashMLA is supported"
@@ -27,28 +33,34 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
                     reason=FLASH_MLA_UNSUPPORTED_REASON)
 @pytest.mark.parametrize("b", [128])
 @pytest.mark.parametrize("s_q", [1, 2])
-@pytest.mark.parametrize("mean_sk", [4096, 8192])
+@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
 @pytest.mark.parametrize("h_q", [16, 32, 64, 128])
 @pytest.mark.parametrize("h_kv", [1])
 @pytest.mark.parametrize("d", [576])
 @pytest.mark.parametrize("dv", [512])
 @pytest.mark.parametrize("block_size", [64])
 @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("varlen", [False, True])
-@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("torch_dtype",
+                         [torch.bfloat16, torch.float16, torch.float8_e4m3fn])
 @torch.inference_mode()
 def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
-                   varlen, dtype):
+                   varlen, torch_dtype):
     device = torch.device("cuda:0")
-    torch.set_default_dtype(dtype)
+    if torch_dtype == torch.float8_e4m3fn:
+        init_dtype = torch.bfloat16
+    else:
+        init_dtype = torch_dtype
+    torch.set_default_dtype(init_dtype)
     torch.set_default_device(device)
     torch.cuda.set_device(device)
     torch.manual_seed(0)
     random.seed(0)
 
     print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
-          f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
+          f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")
 
+    use_fp8 = torch_dtype == torch.float8_e4m3fn
     cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
     if varlen:
         for i in range(b):
@@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
     tile_scheduler_metadata, num_splits = get_mla_metadata(
         cache_seqlens, s_q * h_q // h_kv, h_kv)
 
+    init_dtype = q.dtype
+    if use_fp8:
+        fp8_dtype = torch.float8_e4m3fn
+        descale_q = torch.ones((1), dtype=torch.float32)
+        descale_k = torch.ones((1), dtype=torch.float32)
+
+        q = q.to(fp8_dtype)
+        blocked_k = blocked_k.to(fp8_dtype)
+        blocked_v = blocked_v.to(fp8_dtype)
+    else:
+        descale_q = None
+        descale_k = None
+
     def flash_mla():
         return flash_mla_with_kvcache(
             q,
@@ -81,6 +106,8 @@ def flash_mla():
             tile_scheduler_metadata,
             num_splits,
             causal=causal,
+            descale_q=descale_q,
+            descale_k=descale_k,
         )
 
     def scaled_dot_product_attention(query, key, value, is_causal=False):
@@ -104,29 +131,35 @@ def scaled_dot_product_attention(query, key, value, is_causal=False):
         return attn_weight @ value, lse
 
     def ref_mla():
+        q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
+        blocked_k_ = (blocked_k.to(torch.float) *
+                      descale_k).to(init_dtype) if use_fp8 else blocked_k
+        blocked_v_ = (blocked_v.to(torch.float) *
+                      descale_k).to(init_dtype) if use_fp8 else blocked_v
         out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
         lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
         for i in range(b):
             begin = i * max_seqlen_pad
             end = begin + cache_seqlens[i]
-            ref_O, LSE = scaled_dot_product_attention(
-                q[i].transpose(0, 1),
-                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
-                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
+            out_i, lse_i = scaled_dot_product_attention(
+                q_[i].transpose(0, 1),
+                blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
+                blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                 is_causal=causal,
             )
-            out[i] = ref_O.transpose(0, 1)
-            lse[i] = LSE
+            out[i] = out_i.transpose(0, 1)
+            lse[i] = lse_i
         return out, lse
 
     out_flash, lse_flash = flash_mla()
     out_torch, lse_torch = ref_mla()
-    cal_diff(out_flash, out_torch, "out")
+    cal_diff(out_flash, out_torch, "out", use_fp8)
     cal_diff(lse_flash, lse_torch, "lse")
 
     t = triton.testing.do_bench(flash_mla)
     FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
-    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
-             b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
-    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} "
-          f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
+    bytes = (total_seqlens * h_kv * d +
+             b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
+                 b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
+    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
+          f"{bytes / 10 ** 6 / t:.0f} GB/s")
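The fp8 branch above casts q and the blocked KV cache to torch.float8_e4m3fn and passes unit descale factors to flash_mla_with_kvcache, while ref_mla dequantizes by upcasting and multiplying by the same descales. Below is a minimal standalone sketch of that quantize/descale round trip, assuming a PyTorch build with float8_e4m3fn support; the amax-based per-tensor scale is a hypothetical stand-in for the test's fixed descales of 1.0.

import torch

x = torch.randn(4, 128, dtype=torch.bfloat16)

# Hypothetical per-tensor scale: map the tensor's amax onto the e4m3 maximum (448).
scale = x.abs().amax().float() / 448.0
descale = scale.reshape(1)  # shape (1,), like descale_q / descale_k in the test

x_fp8 = (x.float() / scale).to(torch.float8_e4m3fn)           # quantize
x_ref = (x_fp8.to(torch.float) * descale).to(torch.bfloat16)  # dequantize, as ref_mla does

# e4m3 keeps ~3 mantissa bits, so the round trip stays within a few percent of amax,
# which is why cal_diff relaxes its tolerance on the fp8 path.
assert (x - x_ref).abs().max() <= 0.1 * x.abs().amax()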
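For reference, cal_diff's accuracy check is a cosine-style difference, 1 - 2*<x, y> / (|x|^2 + |y|^2), which is 0 for identical tensors; the change simply relaxes its threshold from 1e-5 to 1e-4 when fp8 inputs are used. A small isolated sketch of the same metric (the helper name and perturbation below are illustrative, not part of the test):

import torch

def cosine_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    # 1 - 2*<x, y> / (|x|^2 + |y|^2): 0 for identical tensors.
    x, y = x.double(), y.double()
    return 1 - 2 * (x * y).sum().item() / max(
        (x * x + y * y).sum().item(), 1e-12)

torch.manual_seed(0)
a = torch.randn(1024)
b = a + 1e-2 * torch.randn(1024)  # perturbation standing in for quantization noise

use_fp8 = True
tol = 1e-4 if use_fp8 else 1e-5   # same thresholds as the updated cal_diff
# A ~1% relative perturbation lands around 5e-5: inside the relaxed fp8
# threshold, but over the strict 1e-5 one.
assert cosine_diff(a, b) < tol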