import torch

import vllm.envs as envs
+from vllm import _custom_ops as ops
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
+from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
                         ModelConfig, PassConfig, VllmConfig,
-                         set_current_vllm_config)
+                         get_current_vllm_config, set_current_vllm_config)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
from .backend import TestBackend


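+# Helper called at the end of the test models' forward passes. As the "Hack"
+# comment below notes, the dynamic fp8 quantization gives the normalized
+# output a consumer so torch.compile does not optimize away the pattern
+# that the fusion pass needs to match.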
+def finisher(hidden_states):
+    custom_ops = get_current_vllm_config().compilation_config.custom_ops
+    if not custom_ops or "+quant_fp8" not in custom_ops:
+        # Hack: use dynamic fp8 quantization to
+        # suppress torch.compile optimizations
+        # that prevent pattern matching
+        return ops.scaled_fp8_quant(hidden_states)
+    else:
+        return hidden_states
+
+
class TestAllReduceRMSNormModel(torch.nn.Module):
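+    # pattern_code identifies the fusion pattern this model is expected to
+    # trigger; it is checked against the "pattern_code" kwarg of the fused
+    # flashinfer op at the end of all_reduce_fusion_pass_on_test_model.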
+    pattern_code = 1

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        super().__init__()
@@ -34,10 +48,14 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        self.norm = RMSNorm(hidden_size, eps)

    def forward(self, hidden_states, residual):
-        view = hidden_states.reshape(-1, self.hidden_size)
-        all_reduce = tensor_model_parallel_all_reduce(view)
-        norm = self.norm(all_reduce)
-        return norm
+        # view = hidden_states.reshape(-1, self.hidden_size)
+        all_reduce = tensor_model_parallel_all_reduce(hidden_states)
+
+        hidden_states = self.norm(all_reduce)
+
+        hidden_states = finisher(hidden_states)
+
+        return hidden_states

    def ops_in_model_before(self):
        return [torch.ops.vllm.all_reduce.default]
@@ -47,6 +65,7 @@ def ops_in_model_after(self):


class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
+    pattern_code = 1

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        super().__init__()
@@ -57,35 +76,54 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
    def forward(self, hidden_states, residual):
        view = hidden_states.reshape(-1, self.hidden_size)
        all_reduce = tensor_model_parallel_all_reduce(view)
-        norm, _ = self.norm(all_reduce, residual)
-        return norm
-
-    def ops_in_model_before(self):
-        return [torch.ops.vllm.all_reduce.default]
+        hidden_states, residual = self.norm(all_reduce, residual)
+        # Hack: use dynamic fp8 quantization to
+        # suppress torch.compile optimizations
+        # that prevent pattern matching
+        hidden_states = finisher(hidden_states)
+        return hidden_states, residual

    def ops_in_model_after(self):
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]

+    def ops_in_model_before(self):
+        return [
+            torch.ops.vllm.all_reduce.default,
+        ]
+

class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
+    pattern_code = 2

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = RMSNorm(hidden_size, eps)
-        self.quant_fp8 = QuantFP8(static=True,
-                                  group_shape=GroupShape.PER_TENSOR)
-        self.scale = torch.rand(1, dtype=torch.float32)
        self.output = torch.empty((token_num, hidden_size),
-                                  dtype=torch.float32)
+                                  dtype=current_platform.fp8_dtype())
+
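+        # Out-variant wrapper: torch.ops._C.static_scaled_fp8_quant writes the
+        # quantized values into the preallocated fp8 self.output buffer and
+        # leaves the scale unchanged.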
+        def _quant_fp8_wrapper(x, scale):
+            torch.ops._C.static_scaled_fp8_quant(self.output, x, scale)
+            return self.output, scale
+
+        vllm_config = get_current_vllm_config()
+        if "+quant_fp8" in vllm_config.compilation_config.custom_ops:
+            # Need to use static_scaled_fp8_quant instead of QuantFP8
+            # due to failure in TestBackend with copying graph
+            self.quant_fp8 = _quant_fp8_wrapper
+        else:
+            self.quant_fp8 = QuantFP8(static=True,
+                                      group_shape=GroupShape.PER_TENSOR)
+        self.scale = torch.rand(1, dtype=torch.float32)

    def forward(self, hidden_states, residual):
        view = hidden_states.reshape(-1, self.hidden_size)
        all_reduce = tensor_model_parallel_all_reduce(view)
        norm_output, residual_output = self.norm(all_reduce, residual)
-        self.output, _ = self.quant_fp8(norm_output, self.scale)
-        return self.output, residual_output
+        output, _ = self.quant_fp8(norm_output, self.scale)
+        hidden_states = finisher(output.to(hidden_states.dtype))
+        return hidden_states, residual_output

    def ops_in_model_after(self):
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
@@ -97,6 +135,7 @@ def ops_in_model_before(self):


class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
+    pattern_code = 3

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        super().__init__()
@@ -143,6 +182,9 @@ def ops_in_model_before(self):
    # TODO: Enable with torch==2.8.0
    # TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
])
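+# Run with both the native and the custom ("+"-prefixed) implementations of
+# rms_norm / quant_fp8; the fusion pass presumably has to match a different
+# graph for each combination.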
+@pytest.mark.parametrize(
+    "custom_ops",
+    [[], ["+rms_norm"], ["+quant_fp8"], ["+rms_norm", "+quant_fp8"]])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [16])
@@ -155,19 +197,23 @@ def ops_in_model_before(self):
    reason="flashinfer is not found or flashinfer "
    "is not compiled with trtllm_allreduce_fusion")
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
-                                        batch_size: int, seq_len: int,
-                                        hidden_size: int, dtype: torch.dtype):
+                                        custom_ops: list[str], batch_size: int,
+                                        seq_len: int, hidden_size: int,
+                                        dtype: torch.dtype):
    num_processes = 2
    if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
            and not current_platform.has_device_capability(100)):
        pytest.skip("Skip as nvfp4 is only supported on "
                    "devices with compute capability 10.0 (Blackwell)")
+    if (test_model != TestAllReduceFusedAddRMSNormStaticQuantFP8Model
+            and ("+quant_fp8" in custom_ops)):
+        pytest.skip()

    def run_torch_spawn(fn, nprocs):
        torch.multiprocessing.spawn(fn,
                                    args=(num_processes, test_model,
                                          batch_size, seq_len, hidden_size,
-                                          dtype),
+                                          dtype, custom_ops),
                                    nprocs=nprocs)

    run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
@@ -176,7 +222,8 @@ def run_torch_spawn(fn, nprocs):
def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
                                          test_model_cls: torch.nn.Module,
                                          batch_size: int, seq_len: int,
-                                          hidden_size: int, dtype: torch.dtype):
+                                          hidden_size: int, dtype: torch.dtype,
+                                          custom_ops: list[str]):
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
@@ -196,10 +243,9 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
-        custom_ops=["+rms_norm", "+quant_fp8"]))
+        level=CompilationLevel.PIECEWISE, custom_ops=custom_ops))
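+    # enable_noop=True lets NoOpEliminationPass run as well, presumably so
+    # redundant reshape/view ops do not get in the way of pattern matching.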
    vllm_config.compilation_config.pass_config = PassConfig(
-        enable_fi_allreduce_fusion=True, enable_noop=False)
+        enable_fi_allreduce_fusion=True, enable_noop=True)
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

    # this is a fake model name to construct the model config
@@ -221,11 +267,18 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,

    hidden_states = torch.randn((token_num, hidden_size),
                                requires_grad=False)
-    residual = torch.randn((token_num, hidden_size), requires_grad=False)
+    residual = torch.randn((token_num, hidden_size),
+                           dtype=torch.float32,
+                           requires_grad=False)

    compiled_model = torch.compile(model, backend=backend)
    compiled_model(hidden_states, residual)

    backend.check_before_ops(model.ops_in_model_before(),
                             fully_replaced=False)
    backend.check_after_ops(model.ops_in_model_after())
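+    # Every fused allreduce+norm op in the post-pass graph must carry the
+    # pattern_code expected by this test model (see the class attributes).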
+    for node in find_op_nodes(
+            torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
+            backend.graph_post_pass):
+        assert (
+            node.kwargs.get("pattern_code") == test_model_cls.pattern_code)