 )
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp,
     GroupShape,
-    QuantFP8,
 )
 from vllm.platforms import current_platform
 from vllm.utils import update_environment_variables
@@ -43,9 +43,9 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         self.eps = eps
         self.norm = RMSNorm(hidden_size, eps)
 
-    def forward(self, hidden_states, residual):
-        view = hidden_states.reshape(-1, self.hidden_size)
-        all_reduce = tensor_model_parallel_all_reduce(view)
+    def forward(self, x):
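+        # avoid having the graph input be an arg to a pattern directly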
+        z = torch.relu(x)
+        all_reduce = tensor_model_parallel_all_reduce(z)
         norm = self.norm(all_reduce)
         return norm
 
@@ -63,9 +63,9 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         self.eps = eps
         self.norm = RMSNorm(hidden_size, eps)
 
-    def forward(self, hidden_states, residual):
-        view = hidden_states.reshape(-1, self.hidden_size)
-        all_reduce = tensor_model_parallel_all_reduce(view)
+    def forward(self, hidden_states):
+        z = residual = torch.relu(hidden_states)
+        all_reduce = tensor_model_parallel_all_reduce(z)
         norm, res = self.norm(all_reduce, residual)
 
         return norm, res
@@ -77,21 +77,53 @@ def ops_in_model_after(self):
         return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
 
 
-class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
+class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
     def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
-        self.norm = RMSNorm(hidden_size, eps)
-        self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
-        self.scale = torch.rand(1, dtype=torch.float32)
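+        # four allreduce+RMSNorm pairs with three FP8 linears between them,
+        # giving the fusion pass several candidate patterns in one graph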
+        self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
+        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
+        self.w = [
+            torch.rand(hidden_size, hidden_size)
+            .to(dtype=current_platform.fp8_dtype())
+            .t()
+            for _ in range(3)
+        ]
 
-    def forward(self, hidden_states, residual):
-        view = hidden_states.reshape(-1, self.hidden_size)
-        all_reduce = tensor_model_parallel_all_reduce(view)
-        norm_output, residual_output = self.norm(all_reduce, residual)
-        quant_out, _ = self.quant_fp8(norm_output, self.scale)
-        return quant_out, residual_output
+        self.fp8_linear = Fp8LinearOp(
+            act_quant_static=True,
+            act_quant_group_shape=GroupShape.PER_TENSOR,
+        )
+
+        self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
+
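+    # each allreduce -> RMSNorm pair below is a candidate for rewriting into
+    # flashinfer_trtllm_fused_allreduce_norm (see ops_in_model_after)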
+    def forward(self, hidden_states):
+        # avoid having graph input be an arg to a pattern directly
+        z = torch.relu(hidden_states)
+        x = resid = tensor_model_parallel_all_reduce(z)
+        y = self.norm[0](x)
+
+        z2 = self.fp8_linear.apply(
+            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
+        )
+
+        x2 = tensor_model_parallel_all_reduce(z2)
+        y2, resid = self.norm[1](x2, resid)
+
+        z3 = self.fp8_linear.apply(
+            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
+        )
+
+        x3 = tensor_model_parallel_all_reduce(z3)
+        y3, resid = self.norm[2](x3, resid)  # use resid here
+
+        z4 = self.fp8_linear.apply(
+            y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
+        )
+        x4 = tensor_model_parallel_all_reduce(z4)
+        y4, resid = self.norm[3](x4, resid)  # use resid here
+        return y4
 
     def ops_in_model_after(self):
         return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
@@ -100,7 +132,7 @@ def ops_in_model_before(self):
         return [
             torch.ops.vllm.all_reduce.default,
             torch.ops._C.static_scaled_fp8_quant.default
-            if self.quant_fp8.enabled()
+            if self.fp8_linear.quant_fp8.enabled()
             else torch.ops.aten.reciprocal.default,
         ]
 
@@ -120,11 +152,10 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         rounded_n = round_up(scale_n, 4)
         self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)
 
-    def forward(self, hidden_states, residual):
-        view = hidden_states.reshape(-1, self.hidden_size)
-        all_reduce = tensor_model_parallel_all_reduce(view)
+    def forward(self, hidden_states):
+        z = residual = torch.relu(hidden_states)
+        all_reduce = tensor_model_parallel_all_reduce(z)
         norm_output, residual_output = self.norm(all_reduce, residual)
-        norm_output = norm_output.reshape(-1, norm_output.shape[-1])
         torch.ops._C.scaled_fp4_quant(
             self.output, norm_output, self.output_scale, self.scale
         )
@@ -146,8 +177,8 @@ def ops_in_model_before(self):
     [
         (TestAllReduceRMSNormModel, False),
         (TestAllReduceFusedAddRMSNormModel, False),
-        (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, True),
-        (TestAllReduceFusedAddRMSNormStaticQuantFP8Model, False),
+        (TestAllReduceRMSNormStaticQuantFP8Model, True),
+        (TestAllReduceRMSNormStaticQuantFP8Model, False),
         (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
     ],
 )
@@ -269,12 +300,16 @@ def all_reduce_fusion_pass_on_test_model(
     model = test_model_cls(hidden_size, token_num)
 
     hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
-    residual = torch.randn((token_num, hidden_size), requires_grad=False)
 
     compiled_model = torch.compile(model, backend=backend)
-    compiled_model(hidden_states, residual)
+    compiled_model(hidden_states)
 
-    assert all_reduce_fusion_pass.matched_count == 1
+    # TODO cleanup
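+    # the FP8 model above contains four allreduce+norm pairs, so the pass
+    # should match four times; the other models contain exactly one pattern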
+    expected = 4 if test_model_cls is TestAllReduceRMSNormStaticQuantFP8Model else 1
+
+    assert all_reduce_fusion_pass.matched_count == expected, (
+        f"{all_reduce_fusion_pass.matched_count=}, {expected=}"
+    )
     backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
     backend.check_after_ops(model.ops_in_model_after())
     del all_reduce_fusion_pass