|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +from collections.abc import Sequence |
| 4 | + |
| 5 | +import pytest |
| 6 | +import torch |
| 7 | +from torch._ops import OpOverload |
| 8 | + |
| 9 | +import vllm.plugins |
| 10 | +from vllm.compilation.fusion import ( |
| 11 | + QUANT_OPS, |
| 12 | + FusedRMSQuantKey, |
| 13 | +) |
| 14 | +from vllm.compilation.noop_elimination import NoOpEliminationPass |
| 15 | +from vllm.compilation.post_cleanup import PostCleanupPass |
| 16 | +from vllm.compilation.rocm_aiter_rmsnorm_fusion import ( |
| 17 | + ROCM_AITER_FUSED_OPS, |
| 18 | + RMSNormAiterQuantFusionPass, |
| 19 | +) |
| 20 | +from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig |
| 21 | +from vllm.model_executor.layers.layernorm import RMSNorm |
| 22 | +from vllm.model_executor.layers.quantization.utils.quant_utils import ( |
| 23 | + GroupShape, |
| 24 | + QuantKey, |
| 25 | + ScaleDesc, |
| 26 | +) |
| 27 | +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( |
| 28 | + Fp8LinearOp, |
| 29 | + maybe_create_device_identity, |
| 30 | +) |
| 31 | +from vllm.platforms import current_platform |
| 32 | + |
| 33 | +from .backend import TestBackend |
| 34 | + |
| 35 | +FP8_DTYPE = current_platform.fp8_dtype() |
| 36 | + |
| 37 | + |
| 38 | +class TestModel(torch.nn.Module): |
| 39 | + def __init__( |
| 40 | + self, |
| 41 | + hidden_size: int, |
| 42 | + eps: float, |
| 43 | + *args, |
| 44 | + **kwargs, |
| 45 | + ): |
| 46 | + super().__init__(*args, **kwargs) |
| 47 | + self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] |
| 48 | + group_shape = GroupShape.PER_TOKEN |
| 49 | + # AITER RMSNorm fusion pass does not support static quantization at the moment. |
| 50 | + self.wscale = [ |
| 51 | + torch.rand(size=(hidden_size, 1), dtype=torch.float32) for _ in range(2) |
| 52 | + ] |
| 53 | + quant_scale = ScaleDesc(torch.float32, static=False, group_shape=group_shape) |
| 54 | + self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) |
| 55 | + |
| 56 | + self.scale = [None for _ in range(2)] |
| 57 | + self.w = [ |
| 58 | + torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() |
| 59 | + for _ in range(2) |
| 60 | + ] |
| 61 | + |
| 62 | + self.fp8_linear = Fp8LinearOp( |
| 63 | + act_quant_static=False, |
| 64 | + act_quant_group_shape=group_shape, |
| 65 | + ) |
| 66 | + |
| 67 | + def forward(self, x): |
| 68 | + resid = torch.sqrt(x) |
| 69 | + y = self.norm[0](x) |
| 70 | + |
| 71 | + x2 = self.fp8_linear.apply( |
| 72 | + y, self.w[0], self.wscale[0], input_scale=self.scale[0] |
| 73 | + ) |
| 74 | + # make sure resid is used for replacement to work |
| 75 | + y2, resid = self.norm[1](x2, resid) |
| 76 | + |
| 77 | + x3 = self.fp8_linear.apply( |
| 78 | + y2, self.w[1], self.wscale[1], input_scale=self.scale[1] |
| 79 | + ) |
| 80 | + y3, resid = self.norm[2](x3, resid) # use resid here |
| 81 | + return y3 |
| 82 | + |
| 83 | + def ops_in_model_before(self) -> Sequence[tuple[OpOverload, bool]]: |
| 84 | + # find fp8 quant ops in the model before fusion using |
| 85 | + # its funcationalized version (without directly targeting the function). |
| 86 | + return [(QUANT_OPS[self.key], False)] |
| 87 | + |
| 88 | + def ops_in_model_after(self) -> Sequence[tuple[OpOverload, bool]]: |
| 89 | + # find aiter rmsnorm fused ops in the model |
| 90 | + # after fusion by directly targeting the function. |
| 91 | + |
| 92 | + return [ |
| 93 | + (ROCM_AITER_FUSED_OPS[FusedRMSQuantKey(self.key, False)], True), |
| 94 | + (ROCM_AITER_FUSED_OPS[FusedRMSQuantKey(self.key, True)], True), |
| 95 | + ] |
| 96 | + |
| 97 | + |
| 98 | +@pytest.mark.parametrize("dtype", [torch.bfloat16]) |
| 99 | +@pytest.mark.parametrize("hidden_size", [2048]) |
| 100 | +@pytest.mark.parametrize("num_tokens", [257]) |
| 101 | +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) |
| 102 | +@pytest.mark.skipif(not current_platform.is_rocm(), reason="Only test on ROCm") |
| 103 | +def test_fusion_rmsnorm_quant( |
| 104 | + dtype: torch.dtype, |
| 105 | + hidden_size: int, |
| 106 | + num_tokens: int, |
| 107 | + eps: float, |
| 108 | + monkeypatch: pytest.MonkeyPatch, |
| 109 | +): |
| 110 | + torch.set_default_device("cuda") |
| 111 | + torch.set_default_dtype(dtype) |
| 112 | + torch.manual_seed(1) |
| 113 | + maybe_create_device_identity() # needed for certain non-cutlass fp8 paths |
| 114 | + |
| 115 | + vllm_config = VllmConfig( |
| 116 | + compilation_config=CompilationConfig( |
| 117 | + level=CompilationLevel.PIECEWISE, |
| 118 | + custom_ops=["+rms_norm", "+quant_fp8"], |
| 119 | + pass_config=PassConfig(enable_fusion=True, enable_noop=True), |
| 120 | + ) |
| 121 | + ) |
| 122 | + with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: |
| 123 | + m.setenv("VLLM_ROCM_USE_AITER", "1") |
| 124 | + m.setenv("VLLM_ROCM_USE_AITER_LINEAR", "0") |
| 125 | + m.setenv("VLLM_ROCM_USE_AITER_RMSNORM", "1") |
| 126 | + |
| 127 | + # Reshape pass is needed for the fusion pass to work |
| 128 | + noop_pass = NoOpEliminationPass(vllm_config) |
| 129 | + fusion_pass = RMSNormAiterQuantFusionPass(vllm_config) |
| 130 | + cleanup_pass = PostCleanupPass(vllm_config) |
| 131 | + |
| 132 | + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) |
| 133 | + model = TestModel(hidden_size, eps) |
| 134 | + |
| 135 | + # First dimension dynamic |
| 136 | + x = torch.rand(num_tokens, hidden_size) |
| 137 | + torch._dynamo.mark_dynamic(x, 0) |
| 138 | + |
| 139 | + result = model(x) |
| 140 | + |
| 141 | + model2 = torch.compile(model, backend=backend) |
| 142 | + result2 = model2(x) |
| 143 | + |
| 144 | + ATOL, RTOL = (1e-2, 1e-2) |
| 145 | + |
| 146 | + torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) |
| 147 | + |
| 148 | + assert fusion_pass.matched_count == 2 |
| 149 | + |
| 150 | + # In pre-nodes, fp8 quant should be there and fused kernels should not |
| 151 | + backend.check_before_fused_auto_custom_ops(model.ops_in_model_before()) |
| 152 | + |
| 153 | + # In post-nodes, fused kernels should be there and fp8 quant should not |
| 154 | + backend.check_after_fused_auto_custom_ops(model.ops_in_model_after()) |
0 commit comments