diff --git a/tests/kernels/quantization/test_fp8_quant.py b/tests/kernels/quantization/test_fp8_quant.py
index 0a3edd4ddc16..c2e70ffb8d34 100644
--- a/tests/kernels/quantization/test_fp8_quant.py
+++ b/tests/kernels/quantization/test_fp8_quant.py
@@ -11,11 +11,9 @@
 from tests.kernels.utils import opcheck
 from vllm.platforms import current_platform
 
-DTYPES = [torch.half, torch.bfloat16, torch.float]
-HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
-                8193]  # Arbitrary values for testing
-HIDDEN_SIZES += list(range(1024, 1033))  # vectorized conversion edge cases
-NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
+DTYPES = [torch.bfloat16, torch.float]
+HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
+NUM_TOKENS = [1, 7, 4096]
 SCALE_UBS = [True, False]
 SEEDS = [0]
 
diff --git a/tests/kernels/quantization/test_int8_quant.py b/tests/kernels/quantization/test_int8_quant.py
index 5a37b976db9e..c1c9bf191d5b 100644
--- a/tests/kernels/quantization/test_int8_quant.py
+++ b/tests/kernels/quantization/test_int8_quant.py
@@ -9,10 +9,9 @@
 from vllm._custom_ops import scaled_int8_quant
 from vllm.platforms import current_platform
 
-DTYPES = [torch.half, torch.bfloat16, torch.float]
-HIDDEN_SIZES = [16, 67, 768, 5137, 8193]  # Arbitrary values for testing
-HIDDEN_SIZES += list(range(1024, 1033))  # vectorized conversion edge cases
-NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
+DTYPES = [torch.bfloat16, torch.float]
+HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
+NUM_TOKENS = [1, 7, 4096]
 SEEDS = [0]
 SCALE = [0.1, 2.1]
 
diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py
index a7cb2a4e7f21..a842d2f1cbe8 100644
--- a/tests/kernels/quantization/test_machete_mm.py
+++ b/tests/kernels/quantization/test_machete_mm.py
@@ -34,8 +34,6 @@
 
 MNK_SHAPES = [
     (1, 128, 128),
-    (1, 512, 1024),
-    (1, 4096, 4096),
     (1, 8192, 28672),
     (13, 8192, 4096),
     (26, 4096, 8192),
@@ -43,8 +41,6 @@
     (64, 8192, 28672),
     (257, 128, 4096),
     (257, 4224, 4160),
-    (257, 4096, 4096),
-    (1024, 4096, 8192),
     (1024, 8192, 4096),
 ]
 
diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py
index 1bd6713ce7fb..cea7700ac329 100644
--- a/tests/kernels/quantization/test_marlin_gemm.py
+++ b/tests/kernels/quantization/test_marlin_gemm.py
@@ -53,12 +53,8 @@
 MNK_FACTORS = [
     (1, 1, 1),
     (1, 4, 8),
-    (1, 7, 5),
-    (13, 17, 67),
     (26, 37, 13),
-    (67, 13, 11),
     (257, 13, 11),
-    (658, 13, 11),
 ]
 
 DTYPES = [torch.float16, torch.bfloat16]
diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py
index 533a4fe59677..03d5d98739c5 100644
--- a/tests/kernels/quantization/test_rocm_skinny_gemms.py
+++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py
@@ -8,15 +8,55 @@
 from vllm.platforms import current_platform
 
 DTYPES = [torch.bfloat16, torch.float16]
-M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
-K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192]  # k % 8 == 0
-N = [1, 2, 3, 4]
+# Specific (N, K, M) combinations for targeted testing
+NKM_FACTORS_LLMM1 = [
+    # Small, medium, large cases
+    (1, 8, 16),
+    (1, 32, 64),
+    (1, 128, 256),
+    (1, 512, 1024),
+    (1, 2048, 4096),
+    # Edge cases with specific K sizes
+    (1, 6144, 1024),
+    (1, 8192, 2048),
+    # Very large case
+    (1, 4096, 8192),
+]
+
+NKM_FACTORS_WVSPLITK = [
+    # Different batch sizes with key dimensions
+    (1, 16, 16),
+    (1, 64, 64),
+    (2, 256, 256),
+    (3, 1024, 1024),
+    (4, 4096, 4096),
+    # Extended K values
+    (1, 9216, 512),
+    (2, 10240, 1024),
+    (4, 16384, 8192),
+    # Minimum M constraint validation (m >= 8)
+    (1, 64, 8),
+    (2, 128, 8),
+    (4, 256, 8),
+]
+
+NKM_FACTORS_WVSPLITK_FP8 = [
+    # FP8-specific cases with K % 16 == 0
+    (1, 16, 16),
+    (1, 64, 64),
+    (2, 512, 512),
+    (3, 2048, 2048),
+    (4, 4096, 4096),
+    # Extended FP8 dimensions not covered by WVSPLITK
+    (1, 14336, 1024),
+    (2, 24576, 2048),
+    (4, 32768, 28672),
+]
+
 SEEDS = [0]
 
 
-@pytest.mark.parametrize("n", [1])  # only test for batch size 1
-@pytest.mark.parametrize("k", K)
-@pytest.mark.parametrize("m", M)
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
 @pytest.mark.parametrize("seed", SEEDS)
@@ -34,9 +74,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
     assert torch.allclose(out, ref_out, rtol=0.01)
 
 
-@pytest.mark.parametrize("n", N)  # only test for batch size <= 4
-@pytest.mark.parametrize("k", K + [9216, 10240, 16384])
-@pytest.mark.parametrize("m", [8] + M)  # m >= 8
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.skipif(not current_platform.is_rocm(),
@@ -54,9 +92,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
     assert torch.allclose(out, ref_out, rtol=0.01)
 
 
-@pytest.mark.parametrize("n", N)  # only test for batch size <= 4
-@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768])  # k % 16 == 0
-@pytest.mark.parametrize("m", M + [28672])  # m >= 16
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.skipif(
diff --git a/tests/kernels/quantization/test_triton_scaled_mm.py b/tests/kernels/quantization/test_triton_scaled_mm.py
index 8a2cc3baced2..24245663fb1d 100644
--- a/tests/kernels/quantization/test_triton_scaled_mm.py
+++ b/tests/kernels/quantization/test_triton_scaled_mm.py
@@ -60,10 +60,18 @@ def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
                                             num_logprobs)
 
 
-@pytest.mark.parametrize("M", [1, 33, 64, 512])
-@pytest.mark.parametrize("N", [256, 971, 20486])
-@pytest.mark.parametrize("K", [128, 496, 1024])
-@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
+MNK_FACTORS = [
+    (1, 256, 128),
+    (33, 256, 496),
+    (64, 971, 1024),
+    (64, 20486, 128),
+    (512, 256, 496),
+    (512, 20486, 1024),
+]
+
+
+@pytest.mark.parametrize("M,N,K", MNK_FACTORS)
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
 @pytest.mark.parametrize("in_dtype", get_8bit_types())
 @pytest.mark.parametrize("use_scalar_scale_a", [True, False])
 @pytest.mark.parametrize("use_scalar_scale_b", [True, False])