diff --git a/tests/compile/test_rocm_aiter_fusion.py b/tests/compile/test_rocm_aiter_fusion.py
new file mode 100644
index 000000000000..65919e0065d8
--- /dev/null
+++ b/tests/compile/test_rocm_aiter_fusion.py
@@ -0,0 +1,227 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Sequence
+
+import pytest
+import torch
+from torch._ops import OpOverload
+
+import vllm.plugins
+from vllm.compilation.fix_functionalization import FixFunctionalizationPass
+from vllm.compilation.fusion import (
+    QUANT_OPS,
+    FusedRMSQuantKey,
+)
+from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
+from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
+from vllm.compilation.rocm_aiter_rmsnorm_fusion import (
+    FusedAddRMSNormAiterDynamicQuantPattern,
+    RMSNormAiterDynamicQuantPattern,
+    RMSNormAiterQuantFusionPass,
+)
+from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
+    QuantKey,
+    ScaleDesc,
+)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp,
+    maybe_create_device_identity,
+)
+from vllm.platforms import current_platform
+
+from .backend import TestBackend
+
+FP8_DTYPE = current_platform.fp8_dtype()
+
+
+class TestModel(torch.nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
+        group_shape = GroupShape.PER_TOKEN
+        # AITER RMSNorm fusion pass does not support static quantization at the moment.
+        self.wscale = [
+            torch.rand(size=(hidden_size, 1), dtype=torch.float32) for _ in range(2)
+        ]
+        quant_scale = ScaleDesc(torch.float32, static=False, group_shape=group_shape)
+        self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
+
+        self.scale = [None for _ in range(2)]
+        self.w = [
+            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
+            for _ in range(2)
+        ]
+
+        self.fp8_linear = Fp8LinearOp(
+            act_quant_static=False,
+            act_quant_group_shape=group_shape,
+        )
+
+    def forward(self, x):
+        resid = torch.sqrt(x)
+        y = self.norm[0](x)
+
+        x2 = self.fp8_linear.apply(
+            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
+        )
+        # make sure resid is used for replacement to work
+        y2, resid = self.norm[1](x2, resid)
+
+        x3 = self.fp8_linear.apply(
+            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
+        )
+        y3, resid = self.norm[2](x3, resid)  # use resid here
+        return y3
+
+    def ops_in_model_before(self) -> Sequence[OpOverload]:
+        return [(QUANT_OPS[self.key])]
+
+    def ops_in_model_after(self) -> Sequence[OpOverload]:
+        ROCM_AITER_FUSED_OPS = (
+            FusedAddRMSNormAiterDynamicQuantPattern.ROCM_AITER_FUSED_OPS
+            | RMSNormAiterDynamicQuantPattern.ROCM_AITER_FUSED_OPS
+        )
+        return [
+            (ROCM_AITER_FUSED_OPS[FusedRMSQuantKey(self.key, False)]),
+            (ROCM_AITER_FUSED_OPS[FusedRMSQuantKey(self.key, True)]),
+        ]
+
+    def ops_in_model(self):
+        return [torch.ops.vllm.rocm_aiter_rmsnorm_fused_add_dynamic_quant.default]
+
+    def ops_not_in_model(self):
+        return []
+
+
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("hidden_size", [64])
+@pytest.mark.parametrize("num_tokens", [257])
+@pytest.mark.parametrize("eps", [1e-5, 1e-6])
+@pytest.mark.skipif(not current_platform.is_rocm(), reason="Only test on ROCm")
+def test_fusion_rmsnorm_quant(
+    dtype: torch.dtype,
+    hidden_size: int,
+    num_tokens: int,
+    eps: float,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    torch.set_default_device("cuda")
+    torch.set_default_dtype(dtype)
+    torch.manual_seed(1)
+    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths
+
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
+            custom_ops=["+rms_norm", "+quant_fp8"],
+            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+        )
+    )
+    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
+        m.setenv("VLLM_ROCM_USE_AITER", "1")
+        m.setenv("VLLM_ROCM_USE_AITER_RMSNORM", "1")
+
+        # Reshape pass is needed for the fusion pass to work
+        noop_pass = NoOpEliminationPass(vllm_config)
+        fusion_pass = RMSNormAiterQuantFusionPass(vllm_config)
+        cleanup_pass = PostCleanupPass(vllm_config)
+
+        backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
+
+        model = TestModel(hidden_size, eps)
+
+        # First dimension dynamic
+        x = torch.rand(num_tokens, hidden_size)
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result = model(x)
+
+        model2 = torch.compile(model, backend=backend)
+
+        result2 = model2(x)
+
+        ATOL, RTOL = (1e-2, 1e-2)
+
+        torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
+
+        assert fusion_pass.matched_count == 2
+
+        # In pre-nodes, fp8 quant should be there and fused kernels should not
+        backend.check_before_ops(model.ops_in_model_before())
+
+        # In post-nodes, fused kernels should be there and fp8 quant should not
+        backend.check_after_ops(model.ops_in_model_after())
+
+
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("hidden_size", [64])
+@pytest.mark.parametrize("num_tokens", [257])
+@pytest.mark.parametrize("eps", [1e-6]) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="Only test on ROCm") +def test_fix_functionalization( + dtype: torch.dtype, + hidden_size: int, + num_tokens: int, + eps: float, + monkeypatch: pytest.MonkeyPatch, +): + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(1) + + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+rms_norm", "+quant_fp8"], + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + ) + ) + with monkeypatch.context() as m: + m.setenv("VLLM_ROCM_USE_AITER", "1") + m.setenv("VLLM_ROCM_USE_AITER_RMSNORM", "1") + + # Reshape pass is needed for the fusion pass to work + noop_pass = NoOpEliminationPass(vllm_config) + fusion_pass = RMSNormAiterQuantFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + passes = [noop_pass, fusion_pass, cleanup_pass] + func_pass = FixFunctionalizationPass(vllm_config) + + backend_no_func = TestBackend(*passes) + backend_func = TestBackend(*passes, func_pass) + + model = TestModel(hidden_size, eps) + + # First dimension dynamic + x = torch.rand(num_tokens, hidden_size) + + torch.compile(model, backend=backend_no_func)(x) + torch.compile(model, backend=backend_func)(x) + + # check if the functionalization pass is applied + for op in model.ops_in_model(): + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in model.ops_in_model(): + if is_func(node, op): + found[op] = True + for op in model.ops_not_in_model(): + if is_func(node, op): + found[op] = True + assert all(found[op] for op in model.ops_in_model()) + assert all(not found.get(op) for op in model.ops_not_in_model()) diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 29462d9ff0e5..a66727419c03 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -97,6 +97,21 @@ def __call__(self, graph: torch.fx.Graph): elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501 mutated_args = {1: "result", 2: "scale", 3: "residual"} self.defunctionalize(graph, node, mutated_args) + elif ( + at_target + == torch.ops.vllm.rocm_aiter_rmsnorm_fused_dynamic_quant.default + ): + mutated_args = {1: "out", 2: "y_scale"} + self.defunctionalize(graph, node, mutated_args) + elif ( + at_target + == torch.ops.vllm.rocm_aiter_rmsnorm_fused_add_dynamic_quant.default + ): + mutated_args = {1: "out", 2: "residual_out", 3: "y_scale"} + self.defunctionalize(graph, node, mutated_args) + elif at_target == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default: # noqa: E501 + mutated_args = {1: "output", 2: "residual_out"} + self.defunctionalize(graph, node, mutated_args) elif at_target in [ torch.ops._C.rms_norm.default, torch.ops._C.rms_norm_static_fp8_quant.default, diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 55fe235e2d2c..724b7e1df62a 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -18,6 +18,12 @@ from .fusion import RMSNormQuantFusionPass from .fusion_attn import AttnFusionPass +if current_platform.is_rocm(): + from .rocm_aiter_rmsnorm_fusion import ( + RMSNormAiterQuantFusionPass, + is_rocm_aiter_rmsnorm_enabled, + ) + if 
 if current_platform.is_cuda():
     from .collective_fusion import AllReduceFusionPass, AsyncTPPass
@@ -100,6 +106,9 @@ def configure(self, config: VllmConfig):
             self.passes += [AllReduceFusionPass(config)]
 
         if self.pass_config.enable_fusion:
+            if is_rocm_aiter_rmsnorm_enabled():
+                self.passes += [RMSNormAiterQuantFusionPass(config)]
+
             self.passes += [RMSNormQuantFusionPass(config)]
             self.passes += [ActivationQuantFusionPass(config)]
diff --git a/vllm/compilation/rocm_aiter_rmsnorm_fusion.py b/vllm/compilation/rocm_aiter_rmsnorm_fusion.py
new file mode 100644
index 000000000000..791923f18111
--- /dev/null
+++ b/vllm/compilation/rocm_aiter_rmsnorm_fusion.py
@@ -0,0 +1,354 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
+import torch
+import torch._inductor.pattern_matcher as pm
+from torch import fx
+from torch._higher_order_ops.auto_functionalize import auto_functionalized
+from torch._inductor.pattern_matcher import PatternMatcherPass
+from torch._ops import OpOverload
+
+import vllm.envs as envs
+
+# add this import to make sure the custom ops are registered
+import vllm.model_executor.layers.layernorm  # noqa: F401
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
+    QuantKey,
+    ScaleDesc,
+    kFp8DynamicTokenSym,
+)
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
+
+from .fusion import (
+    FP8_DTYPE,
+    QUANT_OPS,
+    FusedRMSQuantKey,
+    RMSNormQuantPattern,
+    empty_bf16,
+    empty_fp32,
+)
+from .inductor_pass import enable_fake_mode
+from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
+
+logger = init_logger(__name__)
+
+
+def is_rocm_aiter_rmsnorm_enabled() -> bool:
+    return (
+        current_platform.is_rocm()
+        and envs.VLLM_ROCM_USE_AITER_RMSNORM
+        and envs.VLLM_ROCM_USE_AITER
+    )
+
+
+def rocm_aiter_rmsnorm_fused_dynamic_quant_impl(
+    out: torch.Tensor,
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    y_scale: torch.Tensor,
+    epsilon: float,
+) -> None:
+    import aiter as rocm_aiter
+
+    rocm_aiter.rmsnorm2d_fwd_with_dynamicquant(
+        out, input, y_scale, weight, epsilon, use_model_sensitive_rmsnorm=0
+    )
+
+
+def rocm_aiter_rmsnorm_fused_dynamic_quant_fake(
+    out: torch.Tensor,
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    y_scale: torch.Tensor,
+    epsilon: float,
+) -> None:
+    pass
+
+
+def rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl(
+    out: torch.Tensor,
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    residual_out: torch.Tensor,
+    weight: torch.Tensor,
+    y_scale: torch.Tensor,
+    epsilon: float,
+) -> None:
+    import aiter as rocm_aiter
+
+    rocm_aiter.rmsnorm2d_fwd_with_add_dynamicquant(
+        out,
+        input,
+        residual,
+        residual_out,
+        y_scale,
+        weight,
+        epsilon,
+        use_model_sensitive_rmsnorm=0,
+    )
+
+
+def rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake(
+    out: torch.Tensor,
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    residual_out: torch.Tensor,
+    weight: torch.Tensor,
+    y_scale: torch.Tensor,
+    epsilon: float,
+) -> None:
+    pass
+
+
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_aiter_rmsnorm_fused_dynamic_quant",
+        op_func=rocm_aiter_rmsnorm_fused_dynamic_quant_impl,
+        mutates_args=["out", "y_scale"],
+        fake_impl=rocm_aiter_rmsnorm_fused_dynamic_quant_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_rmsnorm_fused_add_dynamic_quant",
+        op_func=rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl,
+        mutates_args=["out", "residual_out", "y_scale"],
+        fake_impl=rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
+class RMSNormAiterQuantPattern(RMSNormQuantPattern):
+    def __init__(self, epsilon, key):
+        self.epsilon = epsilon
+        self.quant_dtype = key.quant.dtype
+
+        assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}"
+        self.QUANT_OP = QUANT_OPS[key.quant]
+
+
+class RMSNormAiterDynamicQuantPattern(RMSNormAiterQuantPattern):
+    """AITER RMSNorm + Dynamic Quantization pattern."""
+
+    ROCM_AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
+
+    ROCM_AITER_FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
+        FusedRMSQuantKey(
+            kFp8DynamicTokenSym,
+            False,
+        ): torch.ops.vllm.rocm_aiter_rmsnorm_fused_dynamic_quant.default,
+    }
+
+    def __init__(
+        self,
+        epsilon: float,
+        quant_dtype: torch.dtype,
+        group_shape: GroupShape = GroupShape.PER_TOKEN,
+        symmetric=True,
+    ):
+        scale = ScaleDesc(torch.float32, False, group_shape)
+        key = FusedRMSQuantKey(
+            fused_add=False,
+            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
+        )
+
+        assert key in self.ROCM_AITER_FUSED_OPS, (
+            f"unsupported fused aiter rmsnorm+quant op for {key}"
+        )
+        self.FUSED_OP = self.ROCM_AITER_FUSED_OPS[key]
+
+        super().__init__(epsilon, key)
+
+    def register(self, pm_pass):
+        def pattern(
+            result: torch.Tensor,
+            input: torch.Tensor,
+            weight: torch.Tensor,
+            scale: torch.Tensor,
+        ):
+            rms_out = self.ROCM_AITER_RMS_OP(
+                x=input,
+                weight=weight,
+                variance_epsilon=self.epsilon,
+            )
+
+            at = auto_functionalized(
+                self.QUANT_OP, result=result, input=rms_out, scale=scale, scale_ub=None
+            )
+
+            return at[1], at[2]
+
+        def replacement(
+            result: torch.Tensor,
+            input: torch.Tensor,
+            weight: torch.Tensor,
+            scale: torch.Tensor,
+        ):
+            at = auto_functionalized(
+                self.FUSED_OP,
+                out=result,
+                input=input,
+                weight=weight,
+                y_scale=scale,
+                epsilon=self.epsilon,
+            )
+
+            return at[1], at[2]
+
+        inputs = [
+            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
+            empty_bf16(5, 4),  # input
+            empty_bf16(1, 5),  # weight
+            empty_fp32(1, 1),  # scale
+        ]
+
+        pm.register_replacement(
+            pattern,
+            replacement,
+            inputs,
+            pm.fwd_only,
+            pm_pass,
+        )
+
+
+class FusedAddRMSNormAiterDynamicQuantPattern(RMSNormAiterQuantPattern):
+    """AITER RMSNorm Fused Add + Dynamic Quantization pattern."""
+
+    ROCM_AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
+
+    ROCM_AITER_FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
+        FusedRMSQuantKey(
+            kFp8DynamicTokenSym,
+            True,
+        ): torch.ops.vllm.rocm_aiter_rmsnorm_fused_add_dynamic_quant.default,
+    }
+
+    def __init__(
+        self,
+        epsilon: float,
+        quant_dtype: torch.dtype,
+        group_shape: GroupShape = GroupShape.PER_TOKEN,
+        symmetric=True,
+    ):
+        scale = ScaleDesc(torch.float32, False, group_shape)
+        key = FusedRMSQuantKey(
+            fused_add=True,
+            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
+        )
+
+        assert key in self.ROCM_AITER_FUSED_OPS, (
+            f"unsupported fused aiter rmsnorm+quant op for {key}"
+        )
+        self.FUSED_OP = self.ROCM_AITER_FUSED_OPS[key]
+
+        super().__init__(epsilon, key)
+
+    def register(self, pm_pass):
+        def pattern(
+            result: torch.Tensor,
+            rms_result: torch.Tensor,
+            input: torch.Tensor,
+            residual: torch.Tensor,
+            residual_out: torch.Tensor,
+            weight: torch.Tensor,
+            scale: torch.Tensor,
+        ):
+            at = auto_functionalized(
+                self.ROCM_AITER_RMS_ADD_OP,
+                output=rms_result,
+                x=input,
+                residual=residual,
+                residual_out=residual_out,
+                weight=weight,
+                variance_epsilon=self.epsilon,
+            )
+
+            at1 = auto_functionalized(
+                self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
+            )
+
+            return at1[1], at[2], at1[2]
+
+        def replacement(
+            result: torch.Tensor,
+            rms_result: torch.Tensor,
+            input: torch.Tensor,
+            residual: torch.Tensor,
+            residual_out: torch.Tensor,
+            weight: torch.Tensor,
+            scale: torch.Tensor,
+        ):
+            at = auto_functionalized(
+                self.FUSED_OP,
+                out=result,
+                input=input,
+                residual=residual,
+                residual_out=residual_out,
+                weight=weight,
+                y_scale=scale,
+                epsilon=self.epsilon,
+            )
+            # result, residual, scale
+            return at[1], at[2], at[3]
+
+        inputs = [
+            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
+            empty_bf16(5, 4),  # result_rms
+            empty_bf16(5, 4),  # input
+            empty_bf16(5, 4),  # residual
+            empty_bf16(5, 4),  # residual_out
+            empty_bf16(1, 5),  # weight
+            empty_fp32(1, 1),  # scale
+        ]
+
+        pm.register_replacement(
+            pattern,
+            replacement,
+            inputs,
+            pm.fwd_only,
+            pm_pass,
+        )
+
+
+class RMSNormAiterQuantFusionPass(VllmPatternMatcherPass):
+    """
+    This pass fuses aiter rms_norm & quant custom ops into a fused rms_norm_quant op.
+    It also supports aiter fused_add_rms_norm.
+    """
+
+    @enable_fake_mode
+    def __init__(self, config: VllmConfig):
+        super().__init__(config)
+
+        self.patterns: PatternMatcherPass = PatternMatcherPass(
+            pass_name="aiter_rmsnorm_quant_fusion_pass"
+        )
+
+        for epsilon in [1e-5, 1e-6]:
+            # Fuse aiter rms_norm + dynamic per-token fp8 quant
+            RMSNormAiterDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
+
+            # Fuse aiter fused_add_rms_norm + dynamic per-token fp8 quant
+            FusedAddRMSNormAiterDynamicQuantPattern(epsilon, FP8_DTYPE).register(
+                self.patterns
+            )
+
+        self.dump_patterns(config, self.patterns)
+
+    @VllmInductorPass.time_and_log
+    def __call__(self, graph: fx.Graph):
+        self.matched_count = self.patterns.apply(graph)
+        logger.debug("Replaced %s patterns", self.matched_count)
+
+    def uuid(self) -> Any:
+        return self.hash_source(
+            self,
+            RMSNormQuantPattern,
+            RMSNormAiterDynamicQuantPattern,
+            FusedAddRMSNormAiterDynamicQuantPattern,
+        )
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 135fbda2d540..7fa19d84cdfe 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -79,15 +79,15 @@ def rocm_aiter_rms_norm_impl(
 
 
 def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
+    output: torch.Tensor,
     x: torch.Tensor,
     residual: torch.Tensor,
+    residual_out: torch.Tensor,
     weight: torch.Tensor,
     variance_epsilon: float,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> None:
     import aiter as rocm_aiter
 
-    residual_out = torch.empty_like(residual)
-    output = torch.empty_like(x)
     rocm_aiter.rmsnorm2d_fwd_with_add(
         output,  # output
         x,  # input
@@ -96,7 +96,6 @@ def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
         weight,
         variance_epsilon,
     )
-    return output, residual_out
 
 
 def rocm_aiter_rms_norm_fake(
@@ -106,12 +105,14 @@ def rocm_aiter_rms_norm_fake(
 
 
 def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
+    output: torch.Tensor,
     x: torch.Tensor,
     residual: torch.Tensor,
+    residual_out: torch.Tensor,
     weight: torch.Tensor,
     variance_epsilon: float,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    return torch.empty_like(x), torch.empty_like(residual)
+) -> None:
+    pass
 
 
 if current_platform.is_rocm():
@@ -124,27 +125,11 @@ def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
 
     direct_register_custom_op(
         op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
         op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
+        mutates_args=["output", "residual_out"],
         fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
     )
 
 
-def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
-    use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
-        torch.float16,
-        torch.bfloat16,
-    ]
-
-    if use_aiter and with_fused_add:
-        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
-    if use_aiter:
-        return torch.ops.vllm.rocm_aiter_rms_norm
-
-    # fall back to CUDA implementation
-    if with_fused_add:
-        return fused_add_rms_norm
-    return rms_norm
-
-
 @CustomOp.register("rms_norm")
 class RMSNorm(CustomOp):
     """Root mean square normalization.
@@ -177,13 +162,10 @@ def __init__(
             self.weight = nn.Parameter(self.weight)
 
         weight_dtype = self.weight.data.dtype
-        if current_platform.is_rocm():
-            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
-                with_fused_add=False, dtype=weight_dtype
-            )
-            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
-                with_fused_add=True, dtype=weight_dtype
-            )
+        self.rocm_use_aiter = is_rocm_aiter_rmsnorm_enabled() and weight_dtype in [
+            torch.float16,
+            torch.bfloat16,
+        ]
 
     def forward_native(
         self,
@@ -251,12 +233,26 @@ def forward_hip(
             return self.forward_native(x, residual)
 
         add_residual = residual is not None
-        if add_residual:
-            return self.rocm_norm_func_with_add(
-                x, residual, self.weight.data, self.variance_epsilon
-            )
-        else:
-            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)
+        if self.rocm_use_aiter:
+            if add_residual:
+                residual_out = torch.empty_like(residual)
+                output = torch.empty_like(x)
+
+                torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add(
+                    output,
+                    x,
+                    residual,
+                    residual_out,
+                    self.weight.data,
+                    self.variance_epsilon,
+                )
+                return output, residual_out
+            else:
+                return torch.ops.vllm.rocm_aiter_rms_norm(
+                    x, self.weight.data, self.variance_epsilon
+                )
+
+        return self.forward_cuda(x, residual)
 
     def forward_xpu(
         self,