55import torch
66
77import vllm .plugins
8- from vllm .compilation .fusion import RMSNormQuantFusionPass
8+ from vllm .compilation .fusion import FUSED_OPS , FusedRMSQuantKey , RMSNormQuantFusionPass
9+ from vllm .compilation .fx_utils import find_op_nodes
10+ from vllm .compilation .matcher_utils import QUANT_OPS
911from vllm .compilation .noop_elimination import NoOpEliminationPass
1012from vllm .compilation .post_cleanup import PostCleanupPass
1113from vllm .config import (
3335
# FP8 dtype reported by the current platform; used as the QuantKey dtype in TestModel.
3436FP8_DTYPE = current_platform .fp8_dtype ()
3537
# Default overloads of the `_C` rms_norm / fused_add_rms_norm custom ops.
# TestModel.ops_in_model_before_partial returns these when enable_rms_norm is set.
38+ RMS_OP = torch .ops ._C .rms_norm .default
39+ RMS_ADD_OP = torch .ops ._C .fused_add_rms_norm .default
40+
3641
3742class TestModel (torch .nn .Module ):
3843 def __init__ (
@@ -50,7 +55,7 @@ def __init__(
5055 self .wscale = [torch .rand (1 , dtype = torch .float32 ) for _ in range (3 )]
5156 group_shape = GroupShape .PER_TENSOR if static else GroupShape .PER_TOKEN
5257 quant_scale = ScaleDesc (torch .float32 , static , group_shape )
53- self .key = QuantKey (dtype = FP8_DTYPE , scale = quant_scale , symmetric = True )
58+ self .quant_key = QuantKey (dtype = FP8_DTYPE , scale = quant_scale , symmetric = True )
5459 if static :
5560 self .scale = [torch .rand (1 , dtype = torch .float32 ) for _ in range (3 )]
5661 else :
@@ -93,6 +98,22 @@ def forward(self, x):
9398 y4 , resid = self .norm [3 ](x4 , resid ) # use resid here
9499 return y4
95100
def ops_in_model_after(self):
    """Return the fused ops the test checks for in the post-pass graph.

    One ``FUSED_OPS`` entry is looked up per value of the boolean second
    field of ``FusedRMSQuantKey`` (``True`` first, then ``False``), both
    keyed on this model's ``quant_key``.
    NOTE(review): the boolean presumably selects the fused-add variant;
    confirm against ``FusedRMSQuantKey``.
    """
    return [
        FUSED_OPS[FusedRMSQuantKey(self.quant_key, flag)]
        for flag in (True, False)
    ]
106+
def ops_in_model_before(self):
    """Return the quant-side ops the test checks for in the pre-pass graph.

    With ``enable_quant_fp8`` set, this is the ``QUANT_OPS`` entry for the
    model's ``quant_key``; otherwise it is ``torch.ops.aten.reciprocal``.
    """
    if self.enable_quant_fp8:
        return [QUANT_OPS[self.quant_key]]
    return [torch.ops.aten.reciprocal]
113+
def ops_in_model_before_partial(self):
    """Return the norm-side ops the test checks with ``fully_replaced=False``.

    With ``enable_rms_norm`` set, these are the ``RMS_OP`` and
    ``RMS_ADD_OP`` custom ops; otherwise ``torch.ops.aten.rsqrt``.
    """
    if not self.enable_rms_norm:
        return [torch.ops.aten.rsqrt]
    return [RMS_OP, RMS_ADD_OP]
116+
96117
97118@pytest .mark .parametrize ("dtype" , [torch .float16 ]) # , torch.bfloat16])
98119@pytest .mark .parametrize ("hidden_size" , [64 ])
@@ -164,3 +185,18 @@ def test_fusion_rmsnorm_quant(
164185 torch .testing .assert_close (result , result2 , atol = ATOL , rtol = RTOL )
165186
166187 assert fusion_pass .matched_count == 3
188+ backend .check_before_ops (model .ops_in_model_before ())
189+ backend .check_before_ops (
190+ model .ops_in_model_before_partial (), fully_replaced = False
191+ )
192+ backend .check_after_ops (model .ops_in_model_after ())
193+
194+ # If RMSNorm custom op is disabled (native/torch impl used),
195+ # there's a risk that the fused add doesn't get included in the
196+ # replacement and only the rms part gets fused with quant.
197+ # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
198+ if not enable_rms_norm :
199+ n_add_nodes = lambda g : sum (1 for _ in find_op_nodes (torch .ops .aten .add , g ))
200+ # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
201+ assert n_add_nodes (backend .graph_pre_pass ) == 7
202+ assert n_add_nodes (backend .graph_post_pass ) == 2
0 commit comments