diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py
index 21b08e45fd6f..81841be58352 100644
--- a/tests/kernels/attention/test_flashmla.py
+++ b/tests/kernels/attention/test_flashmla.py
@@ -35,11 +35,10 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
 @pytest.mark.parametrize("block_size", [64])
 @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("varlen", [False, True])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @torch.inference_mode()
 def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
-                   varlen):
-    # TODO: parametrize using pytest
-    dtype = torch.bfloat16
+                   varlen, dtype):
     device = torch.device("cuda:0")
     torch.set_default_dtype(dtype)
     torch.set_default_device(device)
@@ -48,7 +47,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
     random.seed(0)
 
     print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
-          f"{d=}, {dv=}, {causal=}, {varlen=}")
+          f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
 
     cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
     if varlen: