@@ -35,11 +35,10 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
3535@pytest .mark .parametrize ("block_size" , [64 ])
3636@pytest .mark .parametrize ("causal" , [True ])
3737@pytest .mark .parametrize ("varlen" , [False , True ])
38+ @pytest .mark .parametrize ("dtype" , [torch .bfloat16 , torch .float16 ])
3839@torch .inference_mode ()
3940def test_flash_mla (b , s_q , mean_sk , h_q , h_kv , d , dv , block_size , causal ,
40- varlen ):
41- # TODO: parametrize using pytest
42- dtype = torch .bfloat16
41+ varlen , dtype ):
4342 device = torch .device ("cuda:0" )
4443 torch .set_default_dtype (dtype )
4544 torch .set_default_device (device )
@@ -48,7 +47,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
4847 random .seed (0 )
4948
5049 print (f"{ b = } , { s_q = } , { mean_sk = } , { h_q = } , { h_kv = } , "
51- f"{ d = } , { dv = } , { causal = } , { varlen = } " )
50+ f"{ d = } , { dv = } , { causal = } , { varlen = } , { dtype = } " )
5251
5352 cache_seqlens = torch .full ((b , ), mean_sk , dtype = torch .int32 )
5453 if varlen :
0 commit comments