@@ -99,7 +99,6 @@ def __init__(self, hidden_size=16, intermediate_size=32):
         super().__init__()
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
-        self.vllm_config = get_current_vllm_config()
         self.gate_proj = torch.nn.Parameter(
             torch.empty((intermediate_size, hidden_size)), requires_grad=False
         )
@@ -152,41 +151,36 @@ def forward(self, hidden_states, residual):
     def ops_in_model_before(self):
         ops_to_remove = [torch.ops.vllm.all_reduce.default]  # Always removed by SP
         # The following are only removed if fusion happens
-        if (
-            self.vllm_config
-            and self.vllm_config.compilation_config.pass_config.enable_fusion
-        ):
-            ops_to_remove.extend(
-                [
-                    torch.ops._C.fused_add_rms_norm.default,
-                    torch.ops._C.static_scaled_fp8_quant.default,
-                ]
-            )
+        config = get_current_vllm_config()
+        if config.compilation_config.pass_config.enable_fusion:
+            ops_to_remove.append(torch.ops._C.fused_add_rms_norm.default)
+            # Only check for static_scaled_fp8_quant if custom quant_fp8 is enabled
+            if "+quant_fp8" in config.compilation_config.custom_ops:
+                ops_to_remove.append(torch.ops._C.static_scaled_fp8_quant.default)
         return ops_to_remove

     def ops_in_model_after(self):
         ops_to_add = [
             torch.ops.vllm.reduce_scatter.default,
             torch.ops.vllm.all_gather.default,
         ]
-        # The following is only added if fusion happens
+        # The following is only added if fusion happens and custom quant_fp8 is enabled
+        config = get_current_vllm_config()
         if (
-            self.vllm_config
-            and self.vllm_config.compilation_config.pass_config.enable_fusion
+            config.compilation_config.pass_config.enable_fusion
+            and "+quant_fp8" in config.compilation_config.custom_ops
         ):
             ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
         return ops_to_add

     def ops_in_model(self):
-        if (
-            self.vllm_config
-            and self.vllm_config.compilation_config.pass_config.enable_fusion
-        ):
-            # If fusion happens, the fused op is the one
+        config = get_current_vllm_config()
+        if config.compilation_config.pass_config.enable_fusion:
+            # If fusion happens with custom quant_fp8, the fused op is the one
             # we check for (de)functionalization
             return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
         else:
-            # If no fusion, the original ops are checked
+            # If no fusion or using native quant, the original ops are checked
             return [
                 torch.ops._C.fused_add_rms_norm.default,
                 # TODO functionalization pass does not handle this yet
@@ -195,7 +189,14 @@ def ops_in_model(self):


 @multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel])
+@pytest.mark.parametrize(
+    "test_model_cls, custom_ops",
+    [
+        (TestModel, ""),
+        (TestQuantModel, "+quant_fp8"),
+        (TestQuantModel, "-quant_fp8"),
+    ],
+)
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seq_len", [16])
 @pytest.mark.parametrize("hidden_size", [16])
@@ -204,6 +205,7 @@ def ops_in_model(self):
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 def test_sequence_parallelism_pass(
     test_model_cls: type[torch.nn.Module],
+    custom_ops: str,
     batch_size: int,
     seq_len: int,
     hidden_size: int,
@@ -220,6 +222,7 @@ def run_torch_spawn(fn, nprocs):
             args=(
                 num_processes,
                 test_model_cls,
+                custom_ops,
                 batch_size,
                 seq_len,
                 hidden_size,
@@ -236,6 +239,7 @@ def sequence_parallelism_pass_on_test_model(
     local_rank: int,
     world_size: int,
     test_model_cls: type[torch.nn.Module],
+    custom_ops: str,
     batch_size: int,
     seq_len: int,
     hidden_size: int,
@@ -264,12 +268,14 @@ def sequence_parallelism_pass_on_test_model(
     initialize_model_parallel(tensor_model_parallel_size=world_size)

     # configure vllm config for SequenceParallelismPass
+    custom_ops_list = custom_ops.split(",") if custom_ops else []
     compilation_config = CompilationConfig(
+        custom_ops=custom_ops_list,
         pass_config=PassConfig(
             enable_sequence_parallelism=True,
             enable_fusion=enable_fusion,
             enable_noop=True,
-        )
+        ),
     )  # NoOp needed for fusion
     device_config = DeviceConfig(device=torch.device("cuda"))
