
Commit c3264d8

Fix partial match rmsnorm+quant, fix allreduce+rmsnorm match

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent: a1c7fdb

File tree (4 files changed: +62 -14 lines)

  tests/compile/test_fusion.py
  vllm/compilation/fusion.py
  vllm/compilation/fx_utils.py
  vllm/model_executor/layers/layernorm.py

tests/compile/test_fusion.py (38 additions, 2 deletions)

@@ -5,7 +5,9 @@
 import torch
 
 import vllm.plugins
-from vllm.compilation.fusion import RMSNormQuantFusionPass
+from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
+from vllm.compilation.fx_utils import find_op_nodes
+from vllm.compilation.matcher_utils import QUANT_OPS
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (
@@ -33,6 +35,9 @@
 
 FP8_DTYPE = current_platform.fp8_dtype()
 
+RMS_OP = torch.ops._C.rms_norm.default
+RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
+
 
 class TestModel(torch.nn.Module):
     def __init__(
@@ -50,7 +55,7 @@ def __init__(
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
         group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
         quant_scale = ScaleDesc(torch.float32, static, group_shape)
-        self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
+        self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
         if static:
             self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
         else:
@@ -93,6 +98,22 @@ def forward(self, x):
         y4, resid = self.norm[3](x4, resid)  # use resid here
         return y4
 
+    def ops_in_model_after(self):
+        return [
+            FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
+            FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
+        ]
+
+    def ops_in_model_before(self):
+        return (
+            [QUANT_OPS[self.quant_key]]
+            if self.enable_quant_fp8
+            else [torch.ops.aten.reciprocal]
+        )
+
+    def ops_in_model_before_partial(self):
+        return [RMS_OP, RMS_ADD_OP] if self.enable_rms_norm else [torch.ops.aten.rsqrt]
+
 
 @pytest.mark.parametrize("dtype", [torch.float16])  # , torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [64])
@@ -164,3 +185,18 @@ def test_fusion_rmsnorm_quant(
     torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
 
     assert fusion_pass.matched_count == 3
+    backend.check_before_ops(model.ops_in_model_before())
+    backend.check_before_ops(
+        model.ops_in_model_before_partial(), fully_replaced=False
+    )
+    backend.check_after_ops(model.ops_in_model_after())
+
+    # If the RMSNorm custom op is disabled (native/torch impl used),
+    # there's a risk that the fused add doesn't get included in the
+    # replacement and only the rms part gets fused with quant.
+    # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
+    if not enable_rms_norm:
+        n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
+        # 7 = 1 (RMS) + 3x2 (3x RMS_ADD, 2 each)
+        assert n_add_nodes(backend.graph_pre_pass) == 7
+        assert n_add_nodes(backend.graph_post_pass) == 2
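A note on the add-node arithmetic in the last hunk: with the RMSNorm custom op disabled, each residual RMSNorm layer lowers to two aten.add nodes (the residual add plus the epsilon add under the rsqrt), while the plain RMSNorm contributes only the epsilon add, giving 1 + 3x2 = 7 adds before the pass. After the pass, only the final residual RMSNorm, which has no quant to fuse with, keeps its 2 adds. A minimal sketch of that native path, assuming it mirrors the torch fallback the test exercises rather than vLLM's exact code:

    import torch

    # Illustrative stand-in for the native (custom-op-disabled) RMSNorm path;
    # names and dtypes here are assumptions, not vLLM's implementation.
    def native_rms_norm(x, weight, eps, residual=None):
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual                    # add #1: the residual add
            residual = x.to(torch.float16)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)     # add #2: the epsilon add
        return x.to(torch.float16) * weight, residual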

vllm/compilation/fusion.py (7 additions, 8 deletions)

@@ -94,9 +94,6 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey):
         self.epsilon = epsilon
         self.quant_dtype = key.quant.dtype
 
-        assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}"
-        self.QUANT_OP = QUANT_OPS[key.quant]
-
         assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
         self.FUSED_OP = FUSED_OPS[key]
 
@@ -334,23 +331,25 @@ def __init__(self, config: VllmConfig):
             pass_name="rmsnorm_quant_fusion_pass"
         )
 
+        # Make sure fused add patterns are before simple rms norm,
+        # as the latter is a subset of the former in torch ops
         for epsilon in [1e-5, 1e-6]:
-            # Fuse rms_norm + static fp8 quant
-            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
-
             # Fuse fused_add_rms_norm + static fp8 quant
             FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                 self.patterns
             )
 
-            # Fuse rms_norm + dynamic per-token fp8 quant
-            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
+            # Fuse rms_norm + static fp8 quant
+            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
 
             # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
             FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                 self.patterns
             )
 
+            # Fuse rms_norm + dynamic per-token fp8 quant
+            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
+
         self.dump_patterns(config, self.patterns)
 
     @VllmInductorPass.time_and_log
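The reordering above matters because, in torch ops, the plain rms_norm pattern is a sub-pattern of the fused_add_rms_norm pattern: the fused variant is just a residual add followed by the same rms_norm computation. If the smaller pattern is registered first, the matcher can claim the rms_norm portion of a residual layer and leave the add outside the replacement, which is the partial match this commit guards against. A rough sketch of that containment, with illustrative helper names rather than vLLM's pattern classes:

    import torch

    def rms_norm(x, weight, eps):
        # the body the rms_norm+quant pattern matches on
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

    def fused_add_rms_norm(x, residual, weight, eps):
        # residual add followed by the exact rms_norm body above; a matcher
        # that tries rms_norm first can match the inner computation and
        # strand this add outside the replacement
        x = x + residual
        return rms_norm(x, weight, eps), x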

vllm/compilation/fx_utils.py (13 additions, 3 deletions)

@@ -3,11 +3,11 @@
 
 import operator
 from collections.abc import Iterable, Iterator
-from typing import Optional
+from typing import Optional, Union
 
 from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
-from torch._ops import OpOverload
+from torch._ops import OpOverload, OpOverloadPacket
 
 
 def is_func(node: fx.Node, target) -> bool:
@@ -67,7 +67,17 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:
 
 
 # An auto-functionalization-aware utility for finding nodes with a specific op
-def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
+# Also handles op overload packets and finds all overloads
+def find_op_nodes(
+    op: Union[OpOverload, OpOverloadPacket], graph: fx.Graph
+) -> Iterator[fx.Node]:
+    if isinstance(op, OpOverloadPacket):
+        for overload in op.overloads():
+            overload_op = getattr(op, overload)
+            yield from find_op_nodes(overload_op, graph)
+        return
+
+    assert isinstance(op, OpOverload)
     if not op._schema.is_mutable:
         yield from graph.find_nodes(op="call_function", target=op)
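With this change, find_op_nodes also accepts an OpOverloadPacket (e.g. torch.ops.aten.add), expanding it into its individual overloads via op.overloads() before searching the graph. A usage sketch in the spirit of the test's add-node counter; the graph argument is any torch.fx.Graph, and the helper name is illustrative:

    import torch
    from vllm.compilation.fx_utils import find_op_nodes

    def n_add_nodes(graph: torch.fx.Graph) -> int:
        # torch.ops.aten.add is an OpOverloadPacket; find_op_nodes walks every
        # overload (add.Tensor, add.Scalar, ...) and yields matching nodes.
        return sum(1 for _ in find_op_nodes(torch.ops.aten.add, graph))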

vllm/model_executor/layers/layernorm.py (4 additions, 1 deletion)

@@ -195,7 +195,10 @@ def forward_static(
         orig_dtype = x.dtype
         x = x.to(torch.float32)
         if residual is not None:
-            x = x + residual.to(torch.float32)
+            # residual promoted f16->f32 automatically,
+            # otherwise Inductor eliminates the casts to and from f16,
+            # increasing memory usage (and complicating pattern matching)
+            x = x + residual
            residual = x.to(orig_dtype)
 
         if x.shape[-1] != hidden_size:
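The new comment relies on standard PyTorch type promotion: adding a float16 residual to a float32 tensor already produces float32, so the explicit residual.to(torch.float32) was redundant and only introduced casts for Inductor to eliminate. A quick standalone check (not vLLM code):

    import torch

    x = torch.randn(4, 8, dtype=torch.float16).to(torch.float32)
    residual = torch.randn(4, 8, dtype=torch.float16)

    # the mixed-dtype add promotes to the wider dtype
    assert (x + residual).dtype == torch.float32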
