From 492c6db24af68a9eb4dc85297c432b070c3c0f13 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 25 Mar 2025 03:02:12 +0000 Subject: [PATCH 1/9] add AITER int8 a8w8 gemm kernel Signed-off-by: tjtanaa --- vllm/_custom_ops.py | 32 +++++++++++++++++++++++++++----- vllm/envs.py | 8 ++++++++ 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d68c097fbe84..575d591c8f46 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -36,6 +36,12 @@ def register_fake(fn): from torch.library import impl_abstract as register_fake +def is_rocm_aiter_gemm_w8a8_scaled_mm_enabled() -> bool: + return current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER_LINEAR \ + and envs.VLLM_ROCM_USE_AITER + + # page attention ops def paged_attention_v1( out: torch.Tensor, @@ -529,11 +535,27 @@ def cutlass_scaled_mm(a: torch.Tensor, n = b.shape[1] if current_platform.is_rocm(): - triton_scaled_mm_module = importlib.import_module( - "vllm.model_executor.layers.quantization.compressed_tensors." - "triton_scaled_mm") - triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + if is_rocm_aiter_gemm_w8a8_scaled_mm_enabled(): + per_tensor_scale_a = (scale_a.numel() == 1) + per_tensor_scale_b = (scale_b.numel() == 1) + per_channel_tensor_scale_a = (scale_a.numel() == m) + per_channel_tensor_scale_b = (scale_b.numel() == n) + + assert ( + (per_tensor_scale_a and per_tensor_scale_b) or + (per_channel_tensor_scale_a and per_channel_tensor_scale_b)), ( + "Currently only support per-tensor or per-channel" + + " scaled w8a8 gemm. `cutlass_scaled_mm` does not support" + + " ATIER block scaled GEMM yet.") + + from aiter import gemm_a8w8_CK + return gemm_a8w8_CK(a, b.t(), scale_a, scale_b, bias).to(out_dtype) + else: + triton_scaled_mm_module = importlib.import_module( + "vllm.model_executor.layers.quantization.compressed_tensors." + "triton_scaled_mm") + triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = torch.empty((m, n), dtype=out_dtype, device=a.device) diff --git a/vllm/envs.py b/vllm/envs.py index 829f9450fb77..0c60cc92fe3e 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -72,6 +72,7 @@ VLLM_DISABLED_KERNELS: list[str] = [] VLLM_USE_V1: bool = True VLLM_ROCM_USE_AITER: bool = False + VLLM_ROCM_USE_AITER_LINEAR: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True @@ -510,6 +511,13 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1")), + # use aiter linear op if aiter ops are enabled + # The following list of related ops + # - scaled_mm (per-tensor / rowwise) + "VLLM_ROCM_USE_AITER_LINEAR": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in + ("true", "1")), + # use aiter rms norm op if aiter ops are enabled. 
"VLLM_ROCM_USE_AITER_RMSNORM": lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in From 895d6ba1547657dee0cc1e4de592c07a83cf7308 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 25 Mar 2025 13:22:23 +0000 Subject: [PATCH 2/9] enable compressed tensors for AITER and ROCm Signed-off-by: tjtanaa --- tests/quantization/test_compressed_tensors.py | 135 +++++++++++++++++- vllm/_custom_ops.py | 11 +- 2 files changed, 141 insertions(+), 5 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 133475a3e06a..42c0505e080b 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -20,6 +20,20 @@ sparse_cutlass_supported) from vllm.platforms import current_platform +ROCM_AITER_SUPPORTED_INT8_MODEL = [ + "neuralmagic/Llama-3.2-1B-quantized.w8a8", + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2" +] + +# TritonScaledMMLinearKernel only supports symmetric quantization. +ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [ + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", + "neuralmagic/Llama-3.2-1B-quantized.w8a8", + "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", +] + @pytest.fixture(scope="function", autouse=True) def use_v0_only(monkeypatch): @@ -57,6 +71,11 @@ def use_v0_only(monkeypatch): ) def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args): model_path, strategy, quant_type, shape_0, is_symmetric = model_args + + if current_platform.is_rocm( + ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL: + pytest.skip(f"Skip model {model_path} as it is not support on ROCm.") + with vllm_runner(model_path, enforce_eager=True) as llm: def check_model(model): @@ -131,6 +150,11 @@ def test_compressed_tensors_w8a8_logprobs( max_tokens, num_logprobs, ): + + if current_platform.is_rocm( + ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL: + pytest.skip(f"Skip model {model_path} as it is not support on ROCm.") + dtype = "bfloat16" # skip language translation prompt for the static per tensor asym model @@ -154,6 +178,9 @@ def test_compressed_tensors_w8a8_logprobs( name_1="vllm", ) + if current_platform.is_rocm(): + torch.cuda.synchronize() + def test_compressed_tensors_no_enforce_eager(vllm_runner): model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change" @@ -179,6 +206,11 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner): ) def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args): model_path, strategy = model_args + + if current_platform.is_rocm( + ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL: + pytest.skip(f"Skip model {model_path} as it is not support on ROCm.") + with vllm_runner(model_path, dtype=torch.float16) as llm: def check_model(model): @@ -207,6 +239,8 @@ def check_model(model): ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4), ], ) +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="The tests are skipped on non-CUDA platform.") def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): model, strategy, group, pack_factor = wNa16_args with vllm_runner(model) as llm: @@ -231,6 +265,8 @@ def check_model(model): assert output +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="This test is skipped on non-CUDA platform.") def 
test_compressed_tensors_w4a16_marlin24(vllm_runner): model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t" with vllm_runner(model_path) as llm: @@ -271,7 +307,9 @@ def check_model(model): if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8): assert len(qkv_proj.input_scale.shape) == 0 - assert qkv_proj.weight.dtype is torch.float8_e4m3fn + assert qkv_proj.weight.dtype is (torch.float8_e4m3fnuz + if current_platform.is_rocm() + else torch.float8_e4m3fn) assert qkv_proj.weight_scale.dtype is torch.float32 assert len(qkv_proj.weight_scale.shape) == 0 @@ -281,6 +319,8 @@ def check_model(model): assert output +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="This test is skipped on non-CUDA platform.") def test_compressed_tensors_kv_cache(vllm_runner): model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme" with vllm_runner(model_path, kv_cache_dtype="fp8") as llm: @@ -309,7 +349,8 @@ def _test_2of4_quant_models(qkv_proj, @pytest.mark.skipif( - not current_platform.has_device_capability(90), + not current_platform.is_cuda() + or not current_platform.has_device_capability(90), reason="Sparse FP8 is not yet supported on this GPU type.", ) @pytest.mark.parametrize( @@ -356,7 +397,8 @@ def check_model(model): @pytest.mark.skipif( - not current_platform.has_device_capability(90), + not current_platform.is_cuda() + or not current_platform.has_device_capability(90), reason="Sparse FP8 is not yet supported on this GPU type.", ) @pytest.mark.parametrize( @@ -571,3 +613,90 @@ def check_model(model): output = llm.generate_greedy("Hello my name is", max_tokens=20) print(output) assert output + + +@pytest.mark.parametrize( + "model_path", + [ + "neuralmagic/Llama-3.2-1B-quantized.w8a8", + ], +) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [10]) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="This tests is skipped on non-ROCm platform.") +def test_compressed_tensors_w8a8_logprobs_rocm_aiter( + hf_runner, + vllm_runner, + example_prompts, + model_path, + max_tokens, + num_logprobs, + monkeypatch, +): + # this will enable VLLM_ROCM_USE_AITER_LINEAR + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + + dtype = "bfloat16" + + # skip language translation prompt for the static per tensor asym model + if (model_path == + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym" + ): # noqa: E501 + example_prompts = example_prompts[0:-1] + + with hf_runner(model_path, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model_path, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize( + "model_args", + [ + ( + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", + "channel", + ), + ], +) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="This tests is skipped on non-ROCm platform.") +def test_compressed_tensors_w8a8_dynamic_per_token_rocm_aiter( + vllm_runner, + model_args, + monkeypatch, +): + + # this will enable VLLM_ROCM_USE_AITER_LINEAR + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + + model_path, strategy = model_args + with vllm_runner(model_path, dtype=torch.float16) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = 
layer.self_attn.qkv_proj + + assert isinstance(qkv_proj.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8) + assert not qkv_proj.scheme.is_static_input_scheme + assert qkv_proj.scheme.strategy == strategy + assert qkv_proj.weight.dtype is torch.int8 + + llm.apply_model(check_model) + + output = llm.generate_greedy(["Hello my name is"], max_tokens=20) + assert output diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 575d591c8f46..8fea4cd1e43c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -541,11 +541,18 @@ def cutlass_scaled_mm(a: torch.Tensor, per_channel_tensor_scale_a = (scale_a.numel() == m) per_channel_tensor_scale_b = (scale_b.numel() == n) + # @TODO: + # Maybe broadcast the per-tensor-scale into per-channel-scale + # if one of the scale is a per-channel-scale. + # For now, it only supports + # per-tensor-per-tensor a8w8 scaled GEMM and + # per-channel-per-channel a8w8 scacled GEMM assert ( (per_tensor_scale_a and per_tensor_scale_b) or (per_channel_tensor_scale_a and per_channel_tensor_scale_b)), ( - "Currently only support per-tensor or per-channel" + - " scaled w8a8 gemm. `cutlass_scaled_mm` does not support" + + "Currently only support per-tensor-per-tensor GEMM " + + " and per-channel-per-channel GEMM through AITER" + " w8a8 scaled gemm. `cutlass_scaled_mm` does not support" + " ATIER block scaled GEMM yet.") from aiter import gemm_a8w8_CK From a5a25a3fea6f5494af57cdaa0fd437e79d3f6735 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Wed, 26 Mar 2025 06:16:40 +0000 Subject: [PATCH 3/9] Add AiterScaledMMKernel abstraction Signed-off-by: tjtanaa --- .../compressed_tensors/triton_scaled_mm.py | 10 ++-- .../kernels/scaled_mm/__init__.py | 4 +- .../quantization/kernels/scaled_mm/aiter.py | 58 +++++++++++++++++++ 3 files changed, 67 insertions(+), 5 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py index b69c5e7a02a7..302172c0e58d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py @@ -143,10 +143,12 @@ def triton_scaled_mm(input: torch.Tensor, scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point() - assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size( - [M, 1]) - assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size( - [N, 1]) + assert scale_a.shape == (1, 1) or scale_a.shape == (M, 1) + assert scale_b.shape == (1, 1) or scale_b.shape == (N, 1) + # assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size( + # [M, 1]) + # assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size( + # [N, 1]) assert out_dtype.is_floating_point assert bias is None or bias.is_floating_point() assert is_weak_contiguous(input) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index a5967995ac88..bedda4c2ab21 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -3,6 +3,8 @@ import os from typing import Dict, List, 
Optional, Type +from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( + AiterScaledMMLinearKernel) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( CutlassScaledMMLinearKernel) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 @@ -17,7 +19,7 @@ _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CutlassScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], - PlatformEnum.ROCM: [TritonScaledMMLinearKernel], + PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], } diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py new file mode 100644 index 000000000000..270a0b3e72da --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + +from .cutlass import CutlassScaledMMLinearKernel +from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig + + +class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if current_platform.is_cpu(): + return ( + False, + "AiterScaledMMLinearKernel requires `aiter` which is not " + + "currently supported on CPU.") + if not current_platform.is_rocm(): + return ( + False, + "AiterScaledMMLinearKernel requires `aiter` which is only " + + "currently supported on ROCm.") + # try import aiter + try: + pass + except Exception: + return ( + False, + "AiterScaledMMLinearKernel requires `aiter` which is not " + + "installed supported on ROCm.") + if not ops.is_rocm_aiter_gemm_w8a8_scaled_mm_enabled(): + return (False, "AiterScaledMMLinearKernel is disabled. 
" + + "Enable by setting `VLLM_ROCM_USE_AITER=1`.") + + if not c.input_symmetric: + return (False, + "AiterScaledMMLinearKernel only supports symmetric " + + "quantization.") + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return super().apply_weights(layer, x, bias) From caf94eebf66e658d60046000d183df51736e09cc Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Thu, 27 Mar 2025 16:07:26 +0000 Subject: [PATCH 4/9] extract aiter ops into AITERScaledMMKernel; update unittests Signed-off-by: tjtanaa --- tests/quantization/test_compressed_tensors.py | 118 ++++-------------- vllm/_custom_ops.py | 39 +----- .../compressed_tensors/triton_scaled_mm.py | 10 +- .../quantization/kernels/scaled_mm/aiter.py | 75 +++++++++-- 4 files changed, 102 insertions(+), 140 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 42c0505e080b..83f4b2e0f37d 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -142,6 +142,8 @@ def zp_valid(zp: Optional[torch.Tensor]): ) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [10]) +@pytest.mark.parametrize( + "use_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_compressed_tensors_w8a8_logprobs( hf_runner, vllm_runner, @@ -149,12 +151,21 @@ def test_compressed_tensors_w8a8_logprobs( model_path, max_tokens, num_logprobs, + use_aiter, + monkeypatch, ): if current_platform.is_rocm( ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL: pytest.skip(f"Skip model {model_path} as it is not support on ROCm.") + if use_aiter: + if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL: + pytest.skip( + f"Skip model {model_path} as it is not support by aiter.") + # this will enable VLLM_ROCM_USE_AITER_LINEAR + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + dtype = "bfloat16" # skip language translation prompt for the static per tensor asym model @@ -204,13 +215,27 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner): ), ], ) -def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args): +@pytest.mark.parametrize( + "use_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_compressed_tensors_w8a8_dynamic_per_token( + vllm_runner, + model_args, + use_aiter, + monkeypatch, +): model_path, strategy = model_args if current_platform.is_rocm( ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL: pytest.skip(f"Skip model {model_path} as it is not support on ROCm.") + if use_aiter: + if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL: + pytest.skip( + f"Skip model {model_path} as it is not support by aiter.") + # this will enable VLLM_ROCM_USE_AITER_LINEAR + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with vllm_runner(model_path, dtype=torch.float16) as llm: def check_model(model): @@ -307,9 +332,7 @@ def check_model(model): if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8): assert len(qkv_proj.input_scale.shape) == 0 - assert qkv_proj.weight.dtype is (torch.float8_e4m3fnuz - if current_platform.is_rocm() - else torch.float8_e4m3fn) + assert qkv_proj.weight.dtype is current_platform.fp8_dtype() assert qkv_proj.weight_scale.dtype is torch.float32 assert len(qkv_proj.weight_scale.shape) == 0 @@ -613,90 
+636,3 @@ def check_model(model): output = llm.generate_greedy("Hello my name is", max_tokens=20) print(output) assert output - - -@pytest.mark.parametrize( - "model_path", - [ - "neuralmagic/Llama-3.2-1B-quantized.w8a8", - ], -) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("num_logprobs", [10]) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="This tests is skipped on non-ROCm platform.") -def test_compressed_tensors_w8a8_logprobs_rocm_aiter( - hf_runner, - vllm_runner, - example_prompts, - model_path, - max_tokens, - num_logprobs, - monkeypatch, -): - # this will enable VLLM_ROCM_USE_AITER_LINEAR - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - - dtype = "bfloat16" - - # skip language translation prompt for the static per tensor asym model - if (model_path == - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym" - ): # noqa: E501 - example_prompts = example_prompts[0:-1] - - with hf_runner(model_path, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - - with vllm_runner(model_path, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize( - "model_args", - [ - ( - "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", - "channel", - ), - ], -) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="This tests is skipped on non-ROCm platform.") -def test_compressed_tensors_w8a8_dynamic_per_token_rocm_aiter( - vllm_runner, - model_args, - monkeypatch, -): - - # this will enable VLLM_ROCM_USE_AITER_LINEAR - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - - model_path, strategy = model_args - with vllm_runner(model_path, dtype=torch.float16) as llm: - - def check_model(model): - layer = model.model.layers[0] - - qkv_proj = layer.self_attn.qkv_proj - - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) - assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8) - assert not qkv_proj.scheme.is_static_input_scheme - assert qkv_proj.scheme.strategy == strategy - assert qkv_proj.weight.dtype is torch.int8 - - llm.apply_model(check_model) - - output = llm.generate_greedy(["Hello my name is"], max_tokens=20) - assert output diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e866d8573861..dc07bad4680f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -36,12 +36,6 @@ def register_fake(fn): from torch.library import impl_abstract as register_fake -def is_rocm_aiter_gemm_w8a8_scaled_mm_enabled() -> bool: - return current_platform.is_rocm() \ - and envs.VLLM_ROCM_USE_AITER_LINEAR \ - and envs.VLLM_ROCM_USE_AITER - - # page attention ops def paged_attention_v1( out: torch.Tensor, @@ -547,34 +541,11 @@ def cutlass_scaled_mm(a: torch.Tensor, n = b.shape[1] if current_platform.is_rocm(): - if is_rocm_aiter_gemm_w8a8_scaled_mm_enabled(): - per_tensor_scale_a = (scale_a.numel() == 1) - per_tensor_scale_b = (scale_b.numel() == 1) - per_channel_tensor_scale_a = (scale_a.numel() == m) - per_channel_tensor_scale_b = (scale_b.numel() == n) - - # @TODO: - # Maybe broadcast the per-tensor-scale into per-channel-scale - # if one of the scale is a per-channel-scale. 
- # For now, it only supports - # per-tensor-per-tensor a8w8 scaled GEMM and - # per-channel-per-channel a8w8 scacled GEMM - assert ( - (per_tensor_scale_a and per_tensor_scale_b) or - (per_channel_tensor_scale_a and per_channel_tensor_scale_b)), ( - "Currently only support per-tensor-per-tensor GEMM " + - " and per-channel-per-channel GEMM through AITER" - " w8a8 scaled gemm. `cutlass_scaled_mm` does not support" + - " ATIER block scaled GEMM yet.") - - from aiter import gemm_a8w8_CK - return gemm_a8w8_CK(a, b.t(), scale_a, scale_b, bias).to(out_dtype) - else: - triton_scaled_mm_module = importlib.import_module( - "vllm.model_executor.layers.quantization.compressed_tensors." - "triton_scaled_mm") - triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + triton_scaled_mm_module = importlib.import_module( + "vllm.model_executor.layers.quantization.compressed_tensors." + "triton_scaled_mm") + triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = torch.empty((m, n), dtype=out_dtype, device=a.device) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py index 302172c0e58d..b69c5e7a02a7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py @@ -143,12 +143,10 @@ def triton_scaled_mm(input: torch.Tensor, scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point() - assert scale_a.shape == (1, 1) or scale_a.shape == (M, 1) - assert scale_b.shape == (1, 1) or scale_b.shape == (N, 1) - # assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size( - # [M, 1]) - # assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size( - # [N, 1]) + assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size( + [M, 1]) + assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size( + [N, 1]) assert out_dtype.is_floating_point assert bias is None or bias.is_floating_point() assert is_weak_contiguous(input) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 270a0b3e72da..c79c815713b9 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -4,6 +4,7 @@ import torch +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.platforms import current_platform @@ -11,6 +12,12 @@ from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig +def is_rocm_aiter_gemm_w8a8_scaled_mm_enabled() -> bool: + return current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER_LINEAR \ + and envs.VLLM_ROCM_USE_AITER + + class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod @@ -20,16 +27,12 @@ def get_min_capability(cls) -> int: @classmethod def can_implement( cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: - if current_platform.is_cpu(): + if current_platform.is_cpu() or not current_platform.is_rocm(): return ( False, "AiterScaledMMLinearKernel requires `aiter` which is not " + - "currently supported on CPU.") - if not current_platform.is_rocm(): - return ( - False, 
- "AiterScaledMMLinearKernel requires `aiter` which is only " + - "currently supported on ROCm.") + "currently supported on CPU and non-ROCm platform.") + # try import aiter try: pass @@ -38,7 +41,7 @@ def can_implement( False, "AiterScaledMMLinearKernel requires `aiter` which is not " + "installed supported on ROCm.") - if not ops.is_rocm_aiter_gemm_w8a8_scaled_mm_enabled(): + if not is_rocm_aiter_gemm_w8a8_scaled_mm_enabled(): return (False, "AiterScaledMMLinearKernel is disabled. " + "Enable by setting `VLLM_ROCM_USE_AITER=1`.") @@ -55,4 +58,58 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return super().apply_weights(layer, x, bias) + """ + `AiterScaledMMLinearKernel` implements a fused version of + `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` + where scale_a * a and scale_b * b are implemented using numpy-style + broadcasting. + Currently only support per-tensor-per-tensor GEMM + and per-channel-per-channel GEMM through AITER + w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support + ATIER block scaled GEMM and mix-precision GEMM. + """ + w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + + # ops.scaled_int8_quant supports both dynamic and static quant: + # * dynamic, i_s is None and x_s computed from x. + # * static, i_s is scalar and x_s is i_s. + symmetric = azp_adj is None + assert symmetric, ("AiterScaledMMLinearKernel only supports" + " symmetric quantization.") + x_q, x_s, x_zp = ops.scaled_int8_quant(x, + i_s, + i_zp, + symmetric=symmetric) + + assert x_zp is None, ("AiterScaledMMLinearKernel only supports" + " symmetric quantization.") + out_dtype = x.dtype + + assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert bias is None or bias.shape[0] == w_q.shape[ + 1] and bias.dtype == out_dtype + + m = x_q.shape[0] # a + n = w_q.shape[1] # b + + per_tensor_scale_a = (x_s.numel() == 1) + per_tensor_scale_b = (w_s.numel() == 1) + per_channel_tensor_scale_a = (x_s.numel() == m) + per_channel_tensor_scale_b = (w_s.numel() == n) + + # @TODO: + # Maybe broadcast the per-tensor-scale into per-channel-scale + # if one of the scale is a per-channel-scale. + # For now, it only supports + # per-tensor-per-tensor a8w8 scaled GEMM and + # per-channel-per-channel a8w8 scacled GEMM + assert ((per_tensor_scale_a and per_tensor_scale_b) or + (per_channel_tensor_scale_a and per_channel_tensor_scale_b)), ( + "Currently only support per-tensor-per-tensor GEMM " + + " and per-channel-per-channel GEMM through AITER" + " w8a8 scaled gemm. 
`cutlass_scaled_mm` does not support" + + " ATIER block scaled GEMM yet.") + + from aiter import gemm_a8w8_CK + return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype) From a26b31c6ab009bd408075ea929f8e8db12740011 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Thu, 27 Mar 2025 16:18:49 +0000 Subject: [PATCH 5/9] annotate import Signed-off-by: tjtanaa --- .../layers/quantization/kernels/scaled_mm/aiter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index c79c815713b9..d868281a83e6 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -33,9 +33,8 @@ def can_implement( "AiterScaledMMLinearKernel requires `aiter` which is not " + "currently supported on CPU and non-ROCm platform.") - # try import aiter try: - pass + import aiter # noqa: F401 # deliberately attempt to import aiter except Exception: return ( False, From 4d231f463d2b4cd859e8c4002ca0b8af9c5871b7 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Fri, 28 Mar 2025 08:51:33 +0000 Subject: [PATCH 6/9] update code and unittest documentation Signed-off-by: tjtanaa --- tests/quantization/test_compressed_tensors.py | 3 +++ .../quantization/kernels/scaled_mm/aiter.py | 25 +++++++++---------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 83f4b2e0f37d..5c928f27c10d 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -20,6 +20,9 @@ sparse_cutlass_supported) from vllm.platforms import current_platform +# AITER only supports per-channel-per-channel INT8 gemm +# and per-tensor-per-tensor INT8 GEMM. +# It does not support mix precision MM and mix quantization scheme. 
ROCM_AITER_SUPPORTED_INT8_MODEL = [ "neuralmagic/Llama-3.2-1B-quantized.w8a8", "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2" diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index d868281a83e6..566ce3645f42 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -12,12 +12,6 @@ from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig -def is_rocm_aiter_gemm_w8a8_scaled_mm_enabled() -> bool: - return current_platform.is_rocm() \ - and envs.VLLM_ROCM_USE_AITER_LINEAR \ - and envs.VLLM_ROCM_USE_AITER - - class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod @@ -27,11 +21,11 @@ def get_min_capability(cls) -> int: @classmethod def can_implement( cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: - if current_platform.is_cpu() or not current_platform.is_rocm(): + if not current_platform.is_rocm(): return ( False, "AiterScaledMMLinearKernel requires `aiter` which is not " + - "currently supported on CPU and non-ROCm platform.") + "currently supported on non-ROCm platform.") try: import aiter # noqa: F401 # deliberately attempt to import aiter @@ -40,7 +34,12 @@ def can_implement( False, "AiterScaledMMLinearKernel requires `aiter` which is not " + "installed supported on ROCm.") - if not is_rocm_aiter_gemm_w8a8_scaled_mm_enabled(): + # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled + if not ( + current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER_LINEAR \ + and envs.VLLM_ROCM_USE_AITER + ): return (False, "AiterScaledMMLinearKernel is disabled. " + "Enable by setting `VLLM_ROCM_USE_AITER=1`.") @@ -94,8 +93,8 @@ def apply_weights(self, per_tensor_scale_a = (x_s.numel() == 1) per_tensor_scale_b = (w_s.numel() == 1) - per_channel_tensor_scale_a = (x_s.numel() == m) - per_channel_tensor_scale_b = (w_s.numel() == n) + per_token_scale_a = (x_s.numel() == m) + per_channel_scale_b = (w_s.numel() == n) # @TODO: # Maybe broadcast the per-tensor-scale into per-channel-scale @@ -103,8 +102,8 @@ def apply_weights(self, # For now, it only supports # per-tensor-per-tensor a8w8 scaled GEMM and # per-channel-per-channel a8w8 scacled GEMM - assert ((per_tensor_scale_a and per_tensor_scale_b) or - (per_channel_tensor_scale_a and per_channel_tensor_scale_b)), ( + assert ((per_tensor_scale_a and per_tensor_scale_b) + or (per_token_scale_a and per_channel_scale_b)), ( "Currently only support per-tensor-per-tensor GEMM " + " and per-channel-per-channel GEMM through AITER" " w8a8 scaled gemm. 
`cutlass_scaled_mm` does not support" + From ab524818df557c60c00053a9a855287a42dc3194 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Fri, 28 Mar 2025 17:49:44 +0000 Subject: [PATCH 7/9] add more comments Signed-off-by: tjtanaa --- .../layers/quantization/kernels/scaled_mm/aiter.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 566ce3645f42..cc814e0d0595 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -36,8 +36,7 @@ def can_implement( "installed supported on ROCm.") # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled if not ( - current_platform.is_rocm() \ - and envs.VLLM_ROCM_USE_AITER_LINEAR \ + envs.VLLM_ROCM_USE_AITER_LINEAR \ and envs.VLLM_ROCM_USE_AITER ): return (False, "AiterScaledMMLinearKernel is disabled. " + @@ -110,4 +109,8 @@ def apply_weights(self, " ATIER block scaled GEMM yet.") from aiter import gemm_a8w8_CK + + # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects + # a to be [M, K] + # b to be [N, K] # cutlass prepare weights in [K, N] format return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype) From 9d8139022958ae43877c220440561cd19ece95a0 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Fri, 28 Mar 2025 17:51:31 +0000 Subject: [PATCH 8/9] add more comments Signed-off-by: tjtanaa --- .../layers/quantization/kernels/scaled_mm/aiter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index cc814e0d0595..b215c3be277f 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -112,5 +112,6 @@ def apply_weights(self, # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects # a to be [M, K] - # b to be [N, K] # cutlass prepare weights in [K, N] format + # b to be [N, K] + # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype) From 975492144c5f89f7bc477c74abfdb10b9718961f Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Fri, 28 Mar 2025 18:25:22 +0000 Subject: [PATCH 9/9] add more comments Signed-off-by: tjtanaa --- .../quantization/kernels/scaled_mm/aiter.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index b215c3be277f..582b12f76562 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -33,14 +33,16 @@ def can_implement( return ( False, "AiterScaledMMLinearKernel requires `aiter` which is not " + - "installed supported on ROCm.") + "installed on ROCm.") # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled if not ( envs.VLLM_ROCM_USE_AITER_LINEAR \ and envs.VLLM_ROCM_USE_AITER ): return (False, "AiterScaledMMLinearKernel is disabled. " + - "Enable by setting `VLLM_ROCM_USE_AITER=1`.") + "Enable by setting `VLLM_ROCM_USE_AITER=1` " + + "and `VLLM_ROCM_USE_AITER_LINEAR=1`. 
" + + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.") if not c.input_symmetric: return (False, @@ -61,7 +63,7 @@ def apply_weights(self, where scale_a * a and scale_b * b are implemented using numpy-style broadcasting. Currently only support per-tensor-per-tensor GEMM - and per-channel-per-channel GEMM through AITER + and per-token-per-channel GEMM through AITER w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support ATIER block scaled GEMM and mix-precision GEMM. """ @@ -98,15 +100,15 @@ def apply_weights(self, # @TODO: # Maybe broadcast the per-tensor-scale into per-channel-scale # if one of the scale is a per-channel-scale. - # For now, it only supports - # per-tensor-per-tensor a8w8 scaled GEMM and - # per-channel-per-channel a8w8 scacled GEMM + # For now, it only supports: + # - per-tensor-per-tensor a8w8 scaled GEMM, and + # - per-token-per-channel a8w8 scaled GEMM assert ((per_tensor_scale_a and per_tensor_scale_b) or (per_token_scale_a and per_channel_scale_b)), ( "Currently only support per-tensor-per-tensor GEMM " + - " and per-channel-per-channel GEMM through AITER" - " w8a8 scaled gemm. `cutlass_scaled_mm` does not support" + - " ATIER block scaled GEMM yet.") + " and per-token-per-channel GEMM through AITER" + " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " + + "does not support AITER block scaled GEMM.") from aiter import gemm_a8w8_CK