diff --git a/tests/compile/backend.py b/tests/compile/backend.py
index 36bc832a1329..91d9eae50de8 100644
--- a/tests/compile/backend.py
+++ b/tests/compile/backend.py
@@ -88,6 +88,36 @@ def check_after_ops(self, ops: Sequence[OpOverload]):
             assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
             assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
 
+    def check_before_fused_auto_custom_ops(
+        self, ops: Sequence[tuple[OpOverload, bool]], fully_replaced: bool = True
+    ):
+        # Currently only used for AITER custom ops that are registered with a
+        # mutable schema directly on the vllm namespace, while the ops they
+        # are fused with appear as auto_functionalized nodes.
+
+        for op, target_op_only in ops:
+            num_pre = len(list(find_op_nodes(op, self.graph_pre_pass, target_op_only)))
+            num_post = len(
+                list(find_op_nodes(op, self.graph_post_pass, target_op_only))
+            )
+            assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
+            assert num_pre > num_post, f"All nodes remain for op {op.name()}"
+            if fully_replaced:
+                assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
+
+    def check_after_fused_auto_custom_ops(self, ops: Sequence[tuple[OpOverload, bool]]):
+        # Currently only used for AITER custom ops that are registered with a
+        # mutable schema directly on the vllm namespace, while the ops they
+        # are fused with appear as auto_functionalized nodes.
+
+        for op, target_op_only in ops:
+            num_pre = len(list(find_op_nodes(op, self.graph_pre_pass, target_op_only)))
+            num_post = len(
+                list(find_op_nodes(op, self.graph_post_pass, target_op_only))
+            )
+            assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
+            assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
+
     def op_count(self, op: OpOverload, before=False) -> int:
         graph = self.graph_pre_pass if before else self.graph_post_pass
         return len(list(find_op_nodes(op, graph)))
diff --git a/tests/compile/test_rocm_aiter_fusion.py b/tests/compile/test_rocm_aiter_fusion.py
new file mode 100644
index 000000000000..bcc922897f67
--- /dev/null
+++ b/tests/compile/test_rocm_aiter_fusion.py
@@ -0,0 +1,154 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Sequence
+
+import pytest
+import torch
+from torch._ops import OpOverload
+
+import vllm.plugins
+from vllm.compilation.fusion import (
+    QUANT_OPS,
+    FusedRMSQuantKey,
+)
+from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
+from vllm.compilation.rocm_aiter_rmsnorm_fusion import (
+    ROCM_AITER_FUSED_OPS,
+    RMSNormAiterQuantFusionPass,
+)
+from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
+    QuantKey,
+    ScaleDesc,
+)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp,
+    maybe_create_device_identity,
+)
+from vllm.platforms import current_platform
+
+from .backend import TestBackend
+
+FP8_DTYPE = current_platform.fp8_dtype()
+
+
+class TestModel(torch.nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
+        group_shape = GroupShape.PER_TOKEN
+        # The AITER RMSNorm fusion pass currently supports only dynamic quantization.
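+        # Weights below use per-channel scales, while activation scales are
+        # computed dynamically per token, so self.scale stays None.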
+        self.wscale = [
+            torch.rand(size=(hidden_size, 1), dtype=torch.float32) for _ in range(2)
+        ]
+        quant_scale = ScaleDesc(torch.float32, static=False, group_shape=group_shape)
+        self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
+
+        self.scale = [None for _ in range(2)]
+        self.w = [
+            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
+            for _ in range(2)
+        ]
+
+        self.fp8_linear = Fp8LinearOp(
+            act_quant_static=False,
+            act_quant_group_shape=group_shape,
+        )
+
+    def forward(self, x):
+        resid = torch.sqrt(x)
+        y = self.norm[0](x)
+
+        x2 = self.fp8_linear.apply(
+            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
+        )
+        # resid must be reused here so the fused-add pattern can match
+        y2, resid = self.norm[1](x2, resid)
+
+        x3 = self.fp8_linear.apply(
+            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
+        )
+        y3, resid = self.norm[2](x3, resid)  # use resid here
+        return y3
+
+    def ops_in_model_before(self) -> Sequence[tuple[OpOverload, bool]]:
+        # find the fp8 quant ops in the model before fusion via their
+        # auto_functionalized form (without targeting the op directly).
+        return [(QUANT_OPS[self.key], False)]
+
+    def ops_in_model_after(self) -> Sequence[tuple[OpOverload, bool]]:
+        # find the fused aiter rmsnorm+quant ops in the model after fusion
+        # by targeting the ops directly.
+
+        return [
+            (ROCM_AITER_FUSED_OPS[FusedRMSQuantKey(self.key, False)], True),
+            (ROCM_AITER_FUSED_OPS[FusedRMSQuantKey(self.key, True)], True),
+        ]
+
+
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("hidden_size", [2048])
+@pytest.mark.parametrize("num_tokens", [257])
+@pytest.mark.parametrize("eps", [1e-5, 1e-6])
+@pytest.mark.skipif(not current_platform.is_rocm(), reason="Only test on ROCm")
+def test_fusion_rmsnorm_quant(
+    dtype: torch.dtype,
+    hidden_size: int,
+    num_tokens: int,
+    eps: float,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    torch.set_default_device("cuda")
+    torch.set_default_dtype(dtype)
+    torch.manual_seed(1)
+    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths
+
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
+            custom_ops=["+rms_norm", "+quant_fp8"],
+            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+        )
+    )
+    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
+        m.setenv("VLLM_ROCM_USE_AITER", "1")
+        m.setenv("VLLM_ROCM_USE_AITER_LINEAR", "0")
+        m.setenv("VLLM_ROCM_USE_AITER_RMSNORM", "1")
+
+        # The no-op elimination pass is needed for the fusion pass to match
+        noop_pass = NoOpEliminationPass(vllm_config)
+        fusion_pass = RMSNormAiterQuantFusionPass(vllm_config)
+        cleanup_pass = PostCleanupPass(vllm_config)
+
+        backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
+        model = TestModel(hidden_size, eps)
+
+        # First dimension dynamic
+        x = torch.rand(num_tokens, hidden_size)
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result = model(x)
+
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)
+
+        ATOL, RTOL = (1e-2, 1e-2)
+
+        torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
+
+        assert fusion_pass.matched_count == 2
+
+        # In pre-nodes, fp8 quant should be there and fused kernels should not
+        backend.check_before_fused_auto_custom_ops(model.ops_in_model_before())
+
+        # In post-nodes, fused kernels should be there and fp8 quant should not
+        backend.check_after_fused_auto_custom_ops(model.ops_in_model_after())
diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py
index 
114b53c74c48..ddd8aff4e268 100644
--- a/vllm/compilation/fx_utils.py
+++ b/vllm/compilation/fx_utils.py
@@ -67,8 +67,29 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:
 
 
 # An auto-functionalization-aware utility for finding nodes with a specific op
-def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
-    if not op._schema.is_mutable:
+def find_op_nodes(
+    op: OpOverload, graph: fx.Graph, target_op_only: bool = False
+) -> Iterator[fx.Node]:
+    """
+    Yield all nodes in the graph that call the given op.
+
+    op (OpOverload):
+        The operator overload to match within the FX graph.
+    graph (fx.Graph):
+        The FX graph to search for nodes calling the specified operator.
+    target_op_only (bool):
+        If True, yield direct calls to the operator even when its schema is
+        mutable. Mutable ops are normally matched only through their
+        auto_functionalized wrapper, but custom ops registered directly on
+        the vllm namespace (e.g. the fused AITER RMSNorm+quant ops) keep a
+        mutable schema and are never auto-functionalized, so they must be
+        matched directly.
+    """
+
+    # Mutable custom ops registered directly on the vllm namespace (such as
+    # the fused AITER RMSNorm+quant ops) have no auto_functionalized form;
+    # match them directly when requested via target_op_only.
+    if not op._schema.is_mutable or target_op_only:
         yield from graph.find_nodes(op="call_function", target=op)
 
     for n in graph.find_nodes(op="call_function", target=auto_functionalized):
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index e323fa1f7734..6b080b0feeda 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -18,6 +18,12 @@
 from .fusion import RMSNormQuantFusionPass
 from .fusion_attn import AttnFusionPass
 
+if current_platform.is_rocm():
+    from .rocm_aiter_rmsnorm_fusion import (
+        RMSNormAiterQuantFusionPass,
+        is_rocm_aiter_rmsnorm_enabled,
+    )
+
 if current_platform.is_cuda():
     from .collective_fusion import AllReduceFusionPass, AsyncTPPass
 
@@ -98,6 +104,9 @@ def configure(self, config: VllmConfig):
             self.passes += [AllReduceFusionPass(config)]
 
         if self.pass_config.enable_fusion:
+            if is_rocm_aiter_rmsnorm_enabled():
+                self.passes += [RMSNormAiterQuantFusionPass(config)]
+
             self.passes += [RMSNormQuantFusionPass(config)]
             self.passes += [ActivationQuantFusionPass(config)]
 
diff --git a/vllm/compilation/rocm_aiter_rmsnorm_fusion.py b/vllm/compilation/rocm_aiter_rmsnorm_fusion.py
new file mode 100644
index 000000000000..e5027adf52e5
--- /dev/null
+++ b/vllm/compilation/rocm_aiter_rmsnorm_fusion.py
@@ -0,0 +1,332 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
+import torch
+import torch._inductor.pattern_matcher as pm
+from torch import fx
+from torch._higher_order_ops.auto_functionalize import auto_functionalized
+from torch._inductor.pattern_matcher import PatternMatcherPass
+from torch._ops import OpOverload
+
+import vllm.envs as envs
+
+# this import ensures the RMSNorm custom ops are registered
+import vllm.model_executor.layers.layernorm  # noqa: F401
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
+    QuantKey,
+    ScaleDesc,
+    kFp8DynamicTokenSym,
+)
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
+
+from .fusion import (
+    FP8_DTYPE,
+    QUANT_OPS,
+    FusedRMSQuantKey,
+    RMSNormQuantPattern,
+    empty_bf16,
+    empty_fp32,
+)
+from .inductor_pass import enable_fake_mode
+from 
.vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass + +logger = init_logger(__name__) + + +def is_rocm_aiter_rmsnorm_enabled() -> bool: + return ( + current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER_RMSNORM + and envs.VLLM_ROCM_USE_AITER + ) + + +def rocm_aiter_rmsnorm_fused_dynamic_quant_impl( + out: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + y_scale: torch.Tensor, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + import aiter as rocm_aiter + + rocm_aiter.rmsnorm2d_fwd_with_dynamicquant( + out, input, y_scale, weight, epsilon, use_model_sensitive_rmsnorm=0 + ) + + return out, y_scale + + +def rocm_aiter_rmsnorm_fused_dynamic_quant_fake( + out: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + y_scale: torch.Tensor, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + return out, y_scale + + +def rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl( + out: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + y_scale: torch.Tensor, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + import aiter as rocm_aiter + + residual_out = torch.empty_like(residual) + + rocm_aiter.rmsnorm2d_fwd_with_add_dynamicquant( + out, + input, + residual, + residual_out, + y_scale, + weight, + epsilon, + use_model_sensitive_rmsnorm=0, + ) + + return out, residual_out, y_scale + + +def rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake( + out: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + y_scale: torch.Tensor, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return out, torch.empty_like(residual), y_scale + + +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_fused_dynamic_quant", + op_func=rocm_aiter_rmsnorm_fused_dynamic_quant_impl, + mutates_args=["out", "y_scale"], + fake_impl=rocm_aiter_rmsnorm_fused_dynamic_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_fused_add_dynamic_quant", + op_func=rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl, + mutates_args=["out", "y_scale"], + fake_impl=rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +ROCM_AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default +ROCM_AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default # noqa: E501 + +ROCM_AITER_FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { + FusedRMSQuantKey( + kFp8DynamicTokenSym, + False, + ): torch.ops.vllm.rocm_aiter_rmsnorm_fused_dynamic_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8DynamicTokenSym, + True, + ): torch.ops.vllm.rocm_aiter_rmsnorm_fused_add_dynamic_quant.default, # noqa: E501 +} + + +class RMSNormAiterQuantPattern(RMSNormQuantPattern): + def __init__(self, epsilon, key): + self.epsilon = epsilon + self.quant_dtype = key.quant.dtype + + assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}" + self.QUANT_OP = QUANT_OPS[key.quant] + + assert key in ROCM_AITER_FUSED_OPS, ( + f"unsupported fused aiter rmsnorm+quant op for {key}" + ) + self.FUSED_OP = ROCM_AITER_FUSED_OPS[key] + + +class RMSNormAiterDynamicQuantPattern(RMSNormAiterQuantPattern): + """AITER RMSNorm + Dynamic Quantization pattern.""" + + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) 
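+        # static=False -> dynamic per-token scales; the AITER fused kernels
+        # registered above only cover this dynamic-quant scheme.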
+ key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + super().__init__(epsilon, key) + + def register(self, pm_pass): + def pattern( + result: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + rms_out = ROCM_AITER_RMS_OP( + x=input, + weight=weight, + variance_epsilon=self.epsilon, + ) + + at = auto_functionalized( + self.QUANT_OP, result=result, input=rms_out, scale=scale, scale_ub=None + ) + + return at[1], at[2] + + def replacement( + result: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + return self.FUSED_OP( + out=result, + input=input, + weight=weight, + y_scale=scale, + epsilon=self.epsilon, + ) + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + empty_fp32(1, 1), # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + ) + + +class FusedAddRMSNormAiterDynamicQuantPattern(RMSNormAiterQuantPattern): + """AITER RMSNorm Fused Add + Dynamic Quantization pattern.""" + + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + super().__init__(epsilon, key) + + def register(self, pm_pass): + def pattern( + result: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + rms_out, residual_out = ROCM_AITER_RMS_ADD_OP( + x=input, + residual=residual, + weight=weight, + variance_epsilon=self.epsilon, + ) + + at = auto_functionalized( + self.QUANT_OP, result=result, input=rms_out, scale=scale, scale_ub=None + ) + + return at[1], residual_out, at[2] + + def replacement( + result: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + return self.FUSED_OP( + out=result, + input=input, + residual=residual, + weight=weight, + y_scale=scale, + epsilon=self.epsilon, + ) + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + empty_fp32(1, 1), # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + ) + + +class RMSNormAiterQuantFusionPass(VllmPatternMatcherPass): + """ + This pass fuses aiter rms_norm & quant custom ops into a fused rms_norm_quant op. + It also supports aiter fused_add_rms_norm. 
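+    Only dynamic per-token fp8 quantization is matched; patterns are
+    registered for the epsilon values 1e-5 and 1e-6.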
+ """ + + @enable_fake_mode + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="aiter_rmsnorm_quant_fusion_pass" + ) + + for epsilon in [1e-5, 1e-6]: + # Fuse aiter rms_norm + dynamic per-token fp8 quant + RMSNormAiterDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + + # Fuse aiter fused_add_rms_norm + dynamic per-token fp8 quant + FusedAddRMSNormAiterDynamicQuantPattern(epsilon, FP8_DTYPE).register( + self.patterns + ) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph): + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def uuid(self) -> Any: + return self.hash_source( + self, + RMSNormQuantPattern, + RMSNormAiterDynamicQuantPattern, + FusedAddRMSNormAiterDynamicQuantPattern, + )
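
Illustrative sketch (not part of the patch): roughly what the new pass rewrites. The AITER op names and keyword arguments are taken from the pattern and replacement code above; the assumption here is that `torch.ops._C.dynamic_per_token_scaled_fp8_quant` is the QUANT_OPS entry for the dynamic per-token fp8 scheme and that `import vllm` registers the custom ops. Running it requires a ROCm build of vLLM with AITER installed.

    import torch
    import vllm  # assumed to register the vllm/_C custom ops used below

    def unfused(result, x, weight, scale, eps):
        # Before fusion: AITER RMSNorm followed by a separate dynamic fp8 quant.
        rms_out = torch.ops.vllm.rocm_aiter_rms_norm(
            x=x, weight=weight, variance_epsilon=eps
        )
        torch.ops._C.dynamic_per_token_scaled_fp8_quant(result, rms_out, scale, None)
        return result, scale

    def fused(result, x, weight, scale, eps):
        # After fusion: a single AITER kernel writes the quantized output and
        # the per-token scales in place.
        return torch.ops.vllm.rocm_aiter_rmsnorm_fused_dynamic_quant(
            out=result, input=x, weight=weight, y_scale=scale, epsilon=eps
        )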