8 changes: 3 additions & 5 deletions tests/kernels/quantization/test_fp8_quant.py
@@ -11,11 +11,9 @@
from tests.kernels.utils import opcheck
from vllm.platforms import current_platform

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
                8193]  # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
DTYPES = [torch.bfloat16, torch.float]
HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
NUM_TOKENS = [1, 7, 4096]
SCALE_UBS = [True, False]
SEEDS = [0]

7 changes: 3 additions & 4 deletions tests/kernels/quantization/test_int8_quant.py
@@ -9,10 +9,9 @@
from vllm._custom_ops import scaled_int8_quant
from vllm.platforms import current_platform

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 5137, 8193] # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
DTYPES = [torch.bfloat16, torch.float]
HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
NUM_TOKENS = [1, 7, 4096]
SEEDS = [0]
SCALE = [0.1, 2.1]

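Note on the fp8/int8 hunks above: pytest stacks @parametrize decorators into a Cartesian product, so trimming these lists shrinks the generated case count multiplicatively. The sketch below is illustrative only (it counts just the parameters visible in the int8 hunk; SEEDS, SCALE, and other decorators multiply further, and the variable names mirror the hunk):

import itertools

import torch

# New int8 lists from the hunk above.
DTYPES = [torch.bfloat16, torch.float]
HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
NUM_TOKENS = [1, 7, 4096]

# Stacked parametrize decorators expand to the product of the list lengths.
new_cases = list(itertools.product(NUM_TOKENS, HIDDEN_SIZES, DTYPES))
old_count = 4 * (5 + len(range(1024, 1033))) * 3  # previous lists: 4 * 14 * 3 = 168
assert len(new_cases) == 3 * 6 * 2 == 36
print(f"{old_count} -> {len(new_cases)} combinations per seed")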
4 changes: 0 additions & 4 deletions tests/kernels/quantization/test_machete_mm.py
@@ -34,17 +34,13 @@

MNK_SHAPES = [
    (1, 128, 128),
    (1, 512, 1024),
    (1, 4096, 4096),
    (1, 8192, 28672),
    (13, 8192, 4096),
    (26, 4096, 8192),
    (64, 4096, 4096),
    (64, 8192, 28672),
    (257, 128, 4096),
    (257, 4224, 4160),
    (257, 4096, 4096),
    (1024, 4096, 8192),
    (1024, 8192, 4096),
]

4 changes: 0 additions & 4 deletions tests/kernels/quantization/test_marlin_gemm.py
@@ -53,12 +53,8 @@
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]

DTYPES = [torch.float16, torch.bfloat16]
60 changes: 48 additions & 12 deletions tests/kernels/quantization/test_rocm_skinny_gemms.py
@@ -8,15 +8,55 @@
from vllm.platforms import current_platform

DTYPES = [torch.bfloat16, torch.float16]
M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192] # k % 8 == 0
N = [1, 2, 3, 4]
# Specific (N, K, M) combinations for targeted testing
NKM_FACTORS_LLMM1 = [
    # Small, medium, large cases
    (1, 8, 16),
    (1, 32, 64),
    (1, 128, 256),
    (1, 512, 1024),
    (1, 2048, 4096),
    # Edge cases with specific K sizes
    (1, 6144, 1024),
    (1, 8192, 2048),
    # Very large case
    (1, 4096, 8192),
]

NKM_FACTORS_WVSPLITK = [
    # Different batch sizes with key dimensions
    (1, 16, 16),
    (1, 64, 64),
    (2, 256, 256),
    (3, 1024, 1024),
    (4, 4096, 4096),
    # Extended K values
    (1, 9216, 512),
    (2, 10240, 1024),
    (4, 16384, 8192),
    # Minimum M constraint validation (m >= 8)
    (1, 64, 8),
    (2, 128, 8),
    (4, 256, 8),
]

NKM_FACTORS_WVSPLITK_FP8 = [
    # FP8-specific cases with K % 16 == 0
    (1, 16, 16),
    (1, 64, 64),
    (2, 512, 512),
    (3, 2048, 2048),
    (4, 4096, 4096),
    # Extended FP8 dimensions not covered by WVSPLITK
    (1, 14336, 1024),
    (2, 24576, 2048),
    (4, 32768, 28672),
]

SEEDS = [0]


@pytest.mark.parametrize("n", [1]) # only test for batch size 1
@pytest.mark.parametrize("k", K)
@pytest.mark.parametrize("m", M)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@@ -34,9 +74,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n", N) # only test for batch size <= 4
@pytest.mark.parametrize("k", K + [9216, 10240, 16384])
@pytest.mark.parametrize("m", [8] + M) # m >= 8
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
@@ -54,9 +92,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n", N) # only test for batch size <= 4
@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768]) # k % 16 == 0
@pytest.mark.parametrize("m", M + [28672]) # m >= 16
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(
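Note on the skinny-GEMM hunk above: separately parametrized N, K, and M lists would make pytest run every |N| * |K| * |M| combination, whereas a single decorator over explicit (n, k, m) tuples runs exactly the listed shapes. The sketch below shows the pattern under stated assumptions: the tuple subset is taken from the hunk, but the test name and its body are placeholders, not the actual kernel check.

import pytest

# Illustrative subset of the curated shapes; one parametrize over explicit
# tuples yields len(NKM_FACTORS_LLMM1) cases instead of a Cartesian product.
NKM_FACTORS_LLMM1 = [
    (1, 8, 16),
    (1, 512, 1024),
    (1, 2048, 4096),
]


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
def test_shape_selection(n, k, m):
    # Placeholder assertion standing in for the kernel call; it only checks the
    # k % 8 == 0 constraint noted on the old K list and the batch-size-1 cases.
    assert k % 8 == 0 and n == 1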
16 changes: 12 additions & 4 deletions tests/kernels/quantization/test_triton_scaled_mm.py
@@ -60,10 +60,18 @@ def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
num_logprobs)


@pytest.mark.parametrize("M", [1, 33, 64, 512])
@pytest.mark.parametrize("N", [256, 971, 20486])
@pytest.mark.parametrize("K", [128, 496, 1024])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
MNK_FACTORS = [
    (1, 256, 128),
    (33, 256, 496),
    (64, 971, 1024),
    (64, 20486, 128),
    (512, 256, 496),
    (512, 20486, 1024),
]


@pytest.mark.parametrize("M,N,K", MNK_FACTORS)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
@pytest.mark.parametrize("in_dtype", get_8bit_types())
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])