Commit e718694

[Feat][aiter][ROCm] Add aiter rmsnorm and fp8 ptpc quant fusion (#735)
* add aiter rmsnorm and quant fusion kernel
* deprint
* disable aiter quant mm for compat

Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Co-authored-by: kliuae-amd <kuanfu.liu@amd.com>
1 parent 7b1fb64 commit e718694
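
For orientation, here is a minimal sketch of how the fused path is switched on, mirroring the configuration used by the new test in this commit; the env vars and config fields are taken directly from that test, and this is a sketch rather than an official recipe:

    # Minimal sketch (assumes a ROCm build of vLLM with aiter available);
    # the env vars and config fields below mirror the new test, nothing more.
    import os

    os.environ["VLLM_ROCM_USE_AITER"] = "1"          # enable aiter kernels
    os.environ["VLLM_ROCM_USE_AITER_RMSNORM"] = "1"  # enable aiter rmsnorm
    os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "0"   # the test disables aiter linear for compat

    from vllm.config import (
        CompilationConfig,
        CompilationLevel,
        PassConfig,
        VllmConfig,
    )

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            custom_ops=["+rms_norm", "+quant_fp8"],
            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
        )
    )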

File tree

5 files changed: +548 -2 lines changed


tests/compile/backend.py

Lines changed: 30 additions & 0 deletions

@@ -88,6 +88,36 @@ def check_after_ops(self, ops: Sequence[OpOverload]):
             assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
             assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
 
+    def check_before_fused_auto_custom_ops(
+        self, ops: Sequence[tuple[OpOverload, bool]], fully_replaced=True
+    ):
+        # Currently only used for aiter custom ops that are registered with a
+        # mutable schema directly on the vllm namespace while they are fused
+        # with auto_functionalized ops.
+
+        for op, target_op_only in ops:
+            num_pre = len(list(find_op_nodes(op, self.graph_pre_pass, target_op_only)))
+            num_post = len(
+                list(find_op_nodes(op, self.graph_post_pass, target_op_only))
+            )
+            assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
+            assert num_pre > num_post, f"All nodes remain for op {op.name()}"
+            if fully_replaced:
+                assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
+
+    def check_after_fused_auto_custom_ops(self, ops: Sequence[tuple[OpOverload, bool]]):
+        # Currently only used for aiter custom ops that are registered with a
+        # mutable schema directly on the vllm namespace while they are fused
+        # with auto_functionalized ops.
+
+        for op, target_op_only in ops:
+            num_pre = len(list(find_op_nodes(op, self.graph_pre_pass, target_op_only)))
+            num_post = len(
+                list(find_op_nodes(op, self.graph_post_pass, target_op_only))
+            )
+            assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
+            assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
+
     def op_count(self, op: OpOverload, before=False) -> int:
         graph = self.graph_pre_pass if before else self.graph_post_pass
         return len(list(find_op_nodes(op, graph)))
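
Usage-wise, the new helpers take (op, target_op_only) pairs instead of bare ops; the boolean tells find_op_nodes whether to match the op only at its direct call sites or also through its auto_functionalized wrapper. A hedged sketch, where `backend` is a TestBackend and `key` a QuantKey built as in the new test file below:

    # Sketch only: `backend` and `key` are assumed to exist as in the test.
    from vllm.compilation.fusion import QUANT_OPS, FusedRMSQuantKey
    from vllm.compilation.rocm_aiter_rmsnorm_fusion import ROCM_AITER_FUSED_OPS

    # False: pre-fusion, the quant op appears wrapped in auto_functionalized.
    backend.check_before_fused_auto_custom_ops([(QUANT_OPS[key], False)])

    # True: the fused aiter op has a mutable schema and is matched directly.
    backend.check_after_fused_auto_custom_ops(
        [(ROCM_AITER_FUSED_OPS[FusedRMSQuantKey(key, True)], True)]
    )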
Lines changed: 154 additions & 0 deletions

@@ -0,0 +1,154 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Sequence
+
+import pytest
+import torch
+from torch._ops import OpOverload
+
+import vllm.plugins
+from vllm.compilation.fusion import (
+    QUANT_OPS,
+    FusedRMSQuantKey,
+)
+from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
+from vllm.compilation.rocm_aiter_rmsnorm_fusion import (
+    ROCM_AITER_FUSED_OPS,
+    RMSNormAiterQuantFusionPass,
+)
+from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
+    QuantKey,
+    ScaleDesc,
+)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp,
+    maybe_create_device_identity,
+)
+from vllm.platforms import current_platform
+
+from .backend import TestBackend
+
+FP8_DTYPE = current_platform.fp8_dtype()
+
+
+class TestModel(torch.nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
+        group_shape = GroupShape.PER_TOKEN
+        # The AITER RMSNorm fusion pass does not support static quantization
+        # at the moment.
+        self.wscale = [
+            torch.rand(size=(hidden_size, 1), dtype=torch.float32) for _ in range(2)
+        ]
+        quant_scale = ScaleDesc(torch.float32, static=False, group_shape=group_shape)
+        self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
+
+        self.scale = [None for _ in range(2)]
+        self.w = [
+            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
+            for _ in range(2)
+        ]
+
+        self.fp8_linear = Fp8LinearOp(
+            act_quant_static=False,
+            act_quant_group_shape=group_shape,
+        )
+
+    def forward(self, x):
+        resid = torch.sqrt(x)
+        y = self.norm[0](x)
+
+        x2 = self.fp8_linear.apply(
+            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
+        )
+        # make sure resid is used, so the replacement can match
+        y2, resid = self.norm[1](x2, resid)
+
+        x3 = self.fp8_linear.apply(
+            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
+        )
+        y3, resid = self.norm[2](x3, resid)  # use resid here
+        return y3
+
+    def ops_in_model_before(self) -> Sequence[tuple[OpOverload, bool]]:
+        # Find fp8 quant ops in the model before fusion via their
+        # auto-functionalized versions (without directly targeting the op).
+        return [(QUANT_OPS[self.key], False)]
+
+    def ops_in_model_after(self) -> Sequence[tuple[OpOverload, bool]]:
+        # Find aiter rmsnorm fused ops in the model after fusion by
+        # directly targeting the op.
+
+        return [
+            (ROCM_AITER_FUSED_OPS[FusedRMSQuantKey(self.key, False)], True),
+            (ROCM_AITER_FUSED_OPS[FusedRMSQuantKey(self.key, True)], True),
+        ]
+
+
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("hidden_size", [2048])
+@pytest.mark.parametrize("num_tokens", [257])
+@pytest.mark.parametrize("eps", [1e-5, 1e-6])
+@pytest.mark.skipif(not current_platform.is_rocm(), reason="Only test on ROCm")
+def test_fusion_rmsnorm_quant(
+    dtype: torch.dtype,
+    hidden_size: int,
+    num_tokens: int,
+    eps: float,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    torch.set_default_device("cuda")
+    torch.set_default_dtype(dtype)
+    torch.manual_seed(1)
+    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths
+
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
+            custom_ops=["+rms_norm", "+quant_fp8"],
+            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+        )
+    )
+    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
+        m.setenv("VLLM_ROCM_USE_AITER", "1")
+        m.setenv("VLLM_ROCM_USE_AITER_LINEAR", "0")
+        m.setenv("VLLM_ROCM_USE_AITER_RMSNORM", "1")
+
+        # The no-op elimination pass is needed for the fusion pass to work
+        noop_pass = NoOpEliminationPass(vllm_config)
+        fusion_pass = RMSNormAiterQuantFusionPass(vllm_config)
+        cleanup_pass = PostCleanupPass(vllm_config)
+
+        backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
+        model = TestModel(hidden_size, eps)
+
+        # First dimension dynamic
+        x = torch.rand(num_tokens, hidden_size)
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result = model(x)
+
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)
+
+        ATOL, RTOL = (1e-2, 1e-2)
+
+        torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
+
+        assert fusion_pass.matched_count == 2
+
+        # In pre-nodes, fp8 quant should be present and fused kernels should not
+        backend.check_before_fused_auto_custom_ops(model.ops_in_model_before())
+
+        # In post-nodes, fused kernels should be present and fp8 quant should not
+        backend.check_after_fused_auto_custom_ops(model.ops_in_model_after())
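
The assertion fusion_pass.matched_count == 2 lines up with the model above: the first norm feeds a quant without a residual add, the second is a fused-add rmsnorm feeding a quant, and the third norm's output is returned unquantized. The two fused aiter ops are accordingly looked up under both values of FusedRMSQuantKey's fused_add flag; a small sketch, with `key` standing in for the dynamic per-token fp8 QuantKey built in TestModel:

    # Sketch: one fused-op table lookup per rmsnorm variant; `key` is assumed
    # to be the QuantKey from TestModel.
    from vllm.compilation.fusion import FusedRMSQuantKey
    from vllm.compilation.rocm_aiter_rmsnorm_fusion import ROCM_AITER_FUSED_OPS

    plain = ROCM_AITER_FUSED_OPS[FusedRMSQuantKey(key, False)]     # rms_norm + quant
    fused_add = ROCM_AITER_FUSED_OPS[FusedRMSQuantKey(key, True)]  # fused_add_rms_norm + quant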

vllm/compilation/fx_utils.py

Lines changed: 23 additions & 2 deletions

@@ -67,8 +67,29 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:
 
 
 # An auto-functionalization-aware utility for finding nodes with a specific op
-def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
-    if not op._schema.is_mutable:
+def find_op_nodes(
+    op: OpOverload, graph: fx.Graph, target_op_only: bool = False
+) -> Iterator[fx.Node]:
+    """
+    Yields all nodes in the graph that call the given op.
+
+    op (OpOverload):
+        The operator overload to match within the FX graph.
+    graph (fx.Graph):
+        The FX graph to search for nodes calling the specified operator.
+    target_op_only (bool):
+        If True, only yield nodes that directly call the specified operator.
+        If False, also yield nodes that call the operator via
+        auto_functionalized. This is useful when `op` is a mutable or
+        custom-registered operator that has no auto-functionalized version.
+    """
+
+    # An op can have a mutable schema without being auto-functionalized:
+    # ops like aiter_rmsnorm_fused_dynamic_quant are registered with a
+    # mutable schema directly on the vllm namespace and are never wrapped
+    # in auto_functionalized.
+    if not op._schema.is_mutable or target_op_only:
         yield from graph.find_nodes(op="call_function", target=op)
 
     for n in graph.find_nodes(op="call_function", target=auto_functionalized):
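
As a usage note on the new target_op_only flag, a short sketch (`my_op` and `graph` are placeholders for an OpOverload and an fx.Graph):

    # Sketch only: `my_op` and `graph` are placeholders.
    from vllm.compilation.fx_utils import find_op_nodes

    # Default: yields direct calls of non-mutable ops plus calls made through
    # auto_functionalized wrappers.
    all_calls = list(find_op_nodes(my_op, graph))

    # New flag: also yields direct calls even when the op's schema is mutable,
    # which is what aiter ops registered directly on the vllm namespace need.
    direct_only = list(find_op_nodes(my_op, graph, target_op_only=True))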

vllm/compilation/pass_manager.py

Lines changed: 9 additions & 0 deletions

@@ -18,6 +18,12 @@
 from .fusion import RMSNormQuantFusionPass
 from .fusion_attn import AttnFusionPass
 
+if current_platform.is_rocm():
+    from .rocm_aiter_rmsnorm_fusion import (
+        RMSNormAiterQuantFusionPass,
+        is_rocm_aiter_rmsnorm_enabled,
+    )
+
 if current_platform.is_cuda():
     from .collective_fusion import AllReduceFusionPass, AsyncTPPass
 
@@ -98,6 +104,9 @@ def configure(self, config: VllmConfig):
             self.passes += [AllReduceFusionPass(config)]
 
         if self.pass_config.enable_fusion:
+            if is_rocm_aiter_rmsnorm_enabled():
+                self.passes += [RMSNormAiterQuantFusionPass(config)]
+
             self.passes += [RMSNormQuantFusionPass(config)]
             self.passes += [ActivationQuantFusionPass(config)]