diff --git a/Dockerfile.rocm_base b/Dockerfile.rocm_base index 5bbe98b0c2204..e33e73b303098 100644 --- a/Dockerfile.rocm_base +++ b/Dockerfile.rocm_base @@ -6,7 +6,7 @@ ARG RCCL_BRANCH="648a58d" ARG RCCL_REPO="https://github.com/ROCm/rccl" ARG TRITON_BRANCH="e5be006" ARG TRITON_REPO="https://github.com/triton-lang/triton.git" -ARG PYTORCH_BRANCH="8d4926e" +ARG PYTORCH_BRANCH="3a585126" ARG PYTORCH_VISION_BRANCH="v0.19.1" ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" diff --git a/docs/source/getting_started/installation/gpu/rocm.inc.md b/docs/source/getting_started/installation/gpu/rocm.inc.md index c8fd11415cfda..336d578de4032 100644 --- a/docs/source/getting_started/installation/gpu/rocm.inc.md +++ b/docs/source/getting_started/installation/gpu/rocm.inc.md @@ -1,6 +1,6 @@ # Installation -vLLM supports AMD GPUs with ROCm 6.2. +vLLM supports AMD GPUs with ROCm 6.3. :::{attention} There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source. @@ -9,7 +9,7 @@ There are no pre-built wheels for this device, so you must either use the pre-bu ## Requirements - GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100) -- ROCm 6.2 +- ROCm 6.3 ## Set up using Python @@ -24,9 +24,15 @@ Currently, there are no pre-built ROCm wheels. - [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html) - [PyTorch](https://pytorch.org/) - For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`. + For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.3_ubuntu24.04_py3.12_pytorch_release_2.4.0`, `rocm/pytorch-nightly`. If you are using docker image, you can skip to Step 3. - Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/) + Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/). Example: + + ```console + # Install PyTorch + $ pip uninstall torch -y + $ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.3 + ``` 1. Install [Triton flash attention for ROCm](https://github.com/ROCm/triton) @@ -37,7 +43,7 @@ Currently, there are no pre-built ROCm wheels. pip uninstall -y triton git clone https://github.com/OpenAI/triton.git cd triton - git checkout e192dba + git checkout e5be006 cd python pip3 install . cd ../.. @@ -49,15 +55,15 @@ Currently, there are no pre-built ROCm wheels. 2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention/tree/ck_tile) - Install ROCm's flash attention (v2.5.9.post1) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support) + Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support) Alternatively, wheels intended for vLLM use can be accessed under the releases. - For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`. + For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`. ```console git clone https://github.com/ROCm/flash-attention.git cd flash-attention - git checkout 3cea2fb + git checkout b7d29fb git submodule update --init GPU_ARCHS="gfx90a" python3 setup.py install cd .. @@ -67,20 +73,16 @@ Currently, there are no pre-built ROCm wheels. You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) ::: -3. Build vLLM. For example, vLLM on ROCM 6.2 can be built with the following steps: +3. Build vLLM. For example, vLLM on ROCM 6.3 can be built with the following steps: ```bash $ pip install --upgrade pip - # Install PyTorch - $ pip uninstall torch -y - $ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.2 - # Build & install AMD SMI $ pip install /opt/rocm/share/amd_smi # Install dependencies - $ pip install --upgrade numba scipy huggingface-hub[cli] + $ pip install --upgrade numba scipy huggingface-hub[cli,hf_transfer] setuptools_scm $ pip install "numpy<2" $ pip install -r requirements-rocm.txt @@ -104,7 +106,7 @@ Currently, there are no pre-built ROCm wheels. For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization). ::: -## Set up using Docker +## Set up using Docker (Recommended) ### Pre-built images @@ -120,7 +122,12 @@ for instructions on how to use this prebuilt docker image. Building the Docker image from source is the recommended way to use vLLM with ROCm. -First, build a docker image from and launch a docker container from the image. +#### (Optional) Build an image with ROCm software stack + +Build a docker image from which setup ROCm software stack needed by the vLLM. +**This step is optional as this rocm_base image is usually prebuilt and store at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under tag `rocm/vllm-dev:base` to speed up user experience.** +If you choose to build this rocm_base image yourself, the steps are as follows. + It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: ```console @@ -131,7 +138,26 @@ It is important that the user kicks off the docker build using buildkit. Either } ``` - uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches. +To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default: + +```console +DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm_base -t rocm/vllm-dev:base . +``` + +#### Build an image with vLLM + +First, build a docker image from and launch a docker container from the image. +It is important that the user kicks off the docker build using buildkit. Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: + +```console +{ + "features": { + "buildkit": true + } +} +``` + + uses ROCm 6.3 by default, but also supports ROCm 5.7, 6.0, 6.1, and 6.2, in older vLLM branches. It provides flexibility to customize the build of docker image using the following arguments: - `BASE_IMAGE`: specifies the base image used when running `docker build`. The default value `rocm/vllm-dev:base` is an image published and maintained by AMD. It is being built using @@ -141,13 +167,13 @@ It provides flexibility to customize the build of docker image using the followi Their values can be passed in when running `docker build` with `--build-arg` options. -To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default: +To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default: ```console DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm . ``` -To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should pick the alternative base image: +To build vllm on ROCm 6.3 for Radeon RX7900 series (gfx1100), you should pick the alternative base image: ```console DOCKER_BUILDKIT=1 docker build --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" -f Dockerfile.rocm -t vllm-rocm . diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 5616935ebdc0c..3a7f0a196b5b8 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -55,10 +55,21 @@ def check_model(model): assert isinstance(attn.quant_method, Fp8KVCacheMethod) - # NOTE: it is valid for scales to be 1.0 (default value), but - # we know these checkpoints have scales < 1.0 - assert 0.0 < attn._k_scale < 1.0 - assert 0.0 < attn._v_scale < 1.0 + if not current_platform.is_rocm(): + # NOTE: This code path requires validation on Non-CUDA platform + # NOTE: it is valid for scales to be 1.0 (default value), but + # we know these checkpoints have scales < 1.0 + assert 0.0 < attn._k_scale < 1.0 + assert 0.0 < attn._v_scale < 1.0 + else: + # NOTE: This code path is for ROCm platform + # NOTE: it is valid for scales to be 1.0 (default value), but + # we know these checkpoints have scales < 1.0 + # However on ROCm platform, the _k_scale and _v_scale will be + # scaled by a factor of 2 as described in + # vllm/model_executor/layers/quantization/kv_cache.py + assert 0.0 < attn._k_scale < (1.0 * 2.0) + assert 0.0 < attn._v_scale < (1.0 * 2.0) llm.apply_model(check_model) @@ -91,13 +102,29 @@ def check_model(model): assert attn._k_scale == 1.0 assert attn._v_scale == 1.0 - if current_platform.has_device_capability(89) and not force_marlin: - # For GPUs with hardware support, we keep weights in fp8 - assert fc1.weight.dtype == torch.float8_e4m3fn - else: - # For GPUs without hardware support, we pack the fp8 weights - # for weight-only quantization using Marlin kernels - assert fc1.weight.dtype == torch.int32 + if current_platform.is_cuda(): + if current_platform.has_device_capability( + 89) and not force_marlin: + # For GPUs with hardware support, we keep weights in fp8 + assert fc1.weight.dtype == torch.float8_e4m3fn + else: + # For GPUs without hardware support, we pack the fp8 weights + # for weight-only quantization using Marlin kernels + assert fc1.weight.dtype == torch.int32 + elif current_platform.is_rocm(): + # Only MI300 and above support quantization='fp8' + if current_platform.has_device_capability( + 94) and not force_marlin: + # For GPUs with hardware support, we keep weights in fp8 + assert fc1.weight.dtype == torch.float8_e4m3fnuz + else: # unsupported ROCm platform + pytest.skip( + "Skip `test_load_fp16_model`. " + "It only runs on ROCm platform with FP8 compute." + " e.g. MI300X and above.") + else: # unsupported platform + pytest.skip("Skip `test_load_fp16_model`. " + "It only runs on CUDA and ROCm platform.") llm.apply_model(check_model) diff --git a/tests/quantization/test_ptpc_fp8.py b/tests/quantization/test_ptpc_fp8.py new file mode 100644 index 0000000000000..9bbb5e327968f --- /dev/null +++ b/tests/quantization/test_ptpc_fp8.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests whether PTPC w8a8 FP8 computation is enabled correctly. + +Run `pytest tests/quantization/test_ptpc_fp8.py --forked`. +""" +import pytest +import torch + +from tests.quantization.utils import is_quant_method_supported +from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod +from vllm.model_executor.layers.quantization.ptpc_fp8 import ( + PTPCFp8LinearMethod) +from vllm.platforms import current_platform + + +@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"), + reason="PTPC FP8 is not supported on this GPU type.") +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="This test is for ROCm GPU.") +@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"]) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"]) +def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None: + + try: + with vllm_runner("facebook/opt-125m", + dtype=dtype, + quantization="ptpc_fp8", + kv_cache_dtype=kv_cache_dtype) as llm: + + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + fc1 = model.model.decoder.layers[0].fc1 + assert isinstance(fc1.quant_method, PTPCFp8LinearMethod) + if kv_cache_dtype == "ptpc_fp8": + attn = model.model.decoder.layers[0].self_attn.attn + assert isinstance(attn.quant_method, Fp8KVCacheMethod) + assert attn._k_scale == 1.0 + assert attn._v_scale == 1.0 + + if current_platform.has_device_capability(94): + # For GPUs with hardware support, we keep weights in fp8 + assert fc1.weight.dtype == torch.float8_e4m3fnuz + else: + pytest.skip() + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output + except AssertionError as e: + if str( + e + ) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.": # noqa: E501 + # If the error message matches, the test passes + pass + else: + # If the error message does not match, re-raise the exception + raise diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 6ded3874fc1dd..6cd508d057a44 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -11,6 +11,7 @@ "deepspeedfp", "tpu_int8", "fp8", + "ptpc_fp8", "fbgemm_fp8", "modelopt", # The order of gptq methods is important for config.py iteration over @@ -99,6 +100,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .modelopt import ModelOptFp8Config from .moe_wna16 import MoeWNA16Config from .neuron_quant import NeuronQuantConfig + from .ptpc_fp8 import PTPCFp8Config from .qqq import QQQConfig from .tpu_int8 import Int8TpuConfig @@ -120,6 +122,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "gptq": GPTQConfig, "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, + "ptpc_fp8": PTPCFp8Config, "qqq": QQQConfig, "hqq": HQQMarlinConfig, "experts_int8": ExpertsInt8Config, diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py new file mode 100644 index 0000000000000..1ded5389e5f45 --- /dev/null +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) +from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, + Fp8KVCacheMethod, + Fp8LinearMethod) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear) +from vllm.platforms import current_platform + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +class PTPCFp8Config(Fp8Config): + """Config class for Per-Token-Per-Channel Dynamic Quantization Fp8.""" + + def __init__( + self, + activation_scheme: str = "dynamic", + ignored_layers: Optional[List[str]] = None, + ) -> None: + if not current_platform.is_rocm(): + raise ValueError( + "ptpc_fp8 quantization is supported only on ROCm.") + + if not current_platform.has_device_capability(94): + raise ValueError( + "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." # noqa: E501 + ) + if activation_scheme == "static": + raise ValueError( + "ptpc_fp8 as of now only support dynamic quantization.") + + super().__init__(is_checkpoint_fp8_serialized=False, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers) + + @classmethod + def get_name(cls) -> str: + return "ptpc_fp8" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config": + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + return cls(activation_scheme=activation_scheme, + ignored_layers=ignored_layers) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return PTPCFp8LinearMethod(self) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + +class PTPCFp8LinearMethod(Fp8LinearMethod): + """Linear method for Per-Token and Per-Channel FP8 Quantization. + Only supports loading quantized BF16 model checkpoints with dynamic + activation scaling. To load FP16 model checkpoints, user must specify + to convert the FP16 model weight loading into BF16. + The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. Only support float8_e4m3fnuz data type due to the limitation of + torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: PTPCFp8Config): + super().__init__(quant_config=quant_config) + # Force weight quantization + self.quant_config.is_checkpoint_fp8_serialized = False + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, + requires_grad=False) + + assert layer.weight.data.dtype == torch.bfloat16, \ + f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501 + # Quantize the weights. + qweight, weight_scale = ops.scaled_fp8_quant( + layer.weight, scale=None, use_per_token_if_dynamic=True) + + # Update the layer with the new values. + layer.weight = Parameter( + qweight.t(), requires_grad=False) # Pretranspose the weight + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return apply_fp8_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=None, + input_scale_ub=None, + bias=bias, + cutlass_fp8_supported=False, + use_per_token_if_dynamic=True) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index dedeb0c296bd4..bea6390f71ff7 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -11,6 +11,13 @@ # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) +# The condition to determine if it is on a platform that supports +# torch._scaled_mm rowwise feature. +# The condition is determined once as the operations +# are time consuming. +USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() + and current_platform.has_device_capability(94)) + def sparse_cutlass_supported() -> bool: if not current_platform.is_cuda(): @@ -172,6 +179,26 @@ def apply_fp8_linear( return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + elif (use_per_token_if_dynamic and not per_tensor_weights + and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM): + # For now validated on ROCm platform + # fp8 rowwise scaling in torch._scaled_mm is introduced in + # https://github.com/pytorch/pytorch/pull/144432 using + # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. + # For CUDA platform please validate if the + # torch._scaled_mm support rowwise scaled GEMM + # Fused GEMM_DQ Rowwise GEMM + output = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale.t(), + bias=bias) + + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + output = output.view(*output_shape) + return output + else: # Fallback for channelwise case, where we use unfused DQ # due to limitations with scaled_mm diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 035766289aebd..1f690b7111ee2 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -72,7 +72,7 @@ class RocmPlatform(Platform): supported_quantization: list[str] = [ "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", - "fbgemm_fp8", "gguf", "quark" + "fbgemm_fp8", "gguf", "quark", "ptpc_fp8" ] @classmethod