import torch

import vllm.envs as envs
-from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
-from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
+from vllm.compilation.fx_utils import find_auto_fn
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
    initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
]


-class TestModel(torch.nn.Module):
-    def __init__(self, hidden_size=16, intermediate_size=32):
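+# Four (relu/linear -> tensor-parallel all_reduce -> RMSNorm) blocks; each block is one
+# pattern that the sequence-parallelism pass is expected to rewrite into
+# reduce_scatter -> RMSNorm -> all_gather.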
+class TestAllReduceRMSNormModel(torch.nn.Module):
+    def __init__(self, hidden_size=16, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.gate_proj = torch.nn.Parameter(
-            torch.empty((intermediate_size, hidden_size))
-        )
-        self.norm = RMSNorm(intermediate_size, 1e-05)
-        # Initialize weights
-        torch.nn.init.normal_(self.gate_proj, std=0.02)
+        self.eps = eps
+        self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
+        self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]

-    def forward(self, hidden_states, residual):
-        """
-        Forward pass implementing the operations in the FX graph
+    def forward(self, x):
+        z = torch.relu(x)
+        x = resid = tensor_model_parallel_all_reduce(z)
+        y = self.norm[0](x)

-        Args:
-            hidden_states: Input tensor
-            residual: Residual tensor from previous layer
+        z2 = torch.mm(y, self.w[0])
+        x2 = tensor_model_parallel_all_reduce(z2)

-        Returns:
-            Tuple containing the output tensor
-        """
-        # Reshape input
-        view = hidden_states.reshape(-1, self.hidden_size)
+        y2, resid = self.norm[1](x2, resid)

-        # matrix multiplication
-        permute = self.gate_proj.permute(1, 0)
-        mm = torch.mm(view, permute)
+        z3 = torch.mm(y2, self.w[1])
+        x3 = tensor_model_parallel_all_reduce(z3)

-        # Tensor parallel all-reduce
-        all_reduce = tensor_model_parallel_all_reduce(mm)
+        y3, resid = self.norm[2](x3, resid)

-        # layer normalization
-        norm_output, residual_output = self.norm(all_reduce, residual)
+        z4 = torch.mm(y3, self.w[2])
+        x4 = tensor_model_parallel_all_reduce(z4)

-        return norm_output, residual_output
+        y4, resid = self.norm[3](x4, resid)
+        return y4

    def ops_in_model_before(self):
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        return [
-            torch.ops.vllm.reduce_scatter.default,
            torch.ops.vllm.all_gather.default,
+            torch.ops.vllm.reduce_scatter.default,
        ]

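+    # ops that must still appear (as auto_functionalized nodes) in the post-pass graph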
    def ops_in_model(self):
-        return [torch.ops._C.fused_add_rms_norm.default]
+        if RMSNorm.enabled():
+            return [
+                torch.ops._C.rms_norm.default,
+                torch.ops._C.fused_add_rms_norm.default,
+            ]
+        else:
+            return []


-class TestQuantModel(torch.nn.Module):
-    def __init__(self, hidden_size=16, intermediate_size=32):
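+# Same four-block structure, but each projection is a static-scale FP8 linear; with
+# fusion enabled, RMSNorm + static FP8 quant should fuse into
+# fused_add_rms_norm_static_fp8_quant after the sequence-parallelism rewrite.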
+class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
+    def __init__(self, hidden_size=16, eps=1e-6):
        super().__init__()
+        self.vllm_config = get_current_vllm_config()
        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.gate_proj = torch.nn.Parameter(
-            torch.empty((intermediate_size, hidden_size)), requires_grad=False
+        self.eps = eps
+        self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
+        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
+        self.w = [
+            torch.rand(hidden_size, hidden_size)
+            .to(dtype=current_platform.fp8_dtype())
+            .t()
+            for _ in range(3)
+        ]
+
+        self.fp8_linear = Fp8LinearOp(
+            act_quant_static=True,
+            act_quant_group_shape=GroupShape.PER_TENSOR,
        )
-        self.norm = RMSNorm(intermediate_size, 1e-05)
-        # Initialize weights
-        torch.nn.init.normal_(self.gate_proj, std=0.02)
-
-        self.fp8_linear = Fp8LinearOp(act_quant_static=True)
-
-        self.scale = torch.rand(1, dtype=torch.float32)
-        # Create a weight that is compatible with torch._scaled_mm,
-        # which expects a column-major layout.
-        self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
-        self.wscale = torch.rand(1, dtype=torch.float32)
-
-    def forward(self, hidden_states, residual):
-        """
-        Forward pass implementing the operations in the FX graph
-
-        Args:
-            hidden_states: Input tensor
-            residual: Residual tensor from previous layer
-
-        Returns:
-            Tuple containing the output tensor
-        """
-        # Reshape input
-        view = hidden_states.reshape(-1, self.hidden_size)
-
-        # matrix multiplication
-        permute = self.gate_proj.permute(1, 0)
-        mm = torch.mm(view, permute)
-
-        # Tensor parallel all-reduce
-        all_reduce = tensor_model_parallel_all_reduce(mm)
-
-        # layer normalization
-        norm_output, residual_output = self.norm(all_reduce, residual)
-
-        # scaled_mm with static input quantization
-        fp8_linear_result = self.fp8_linear.apply(
-            norm_output,
-            self.w,
-            self.wscale,
-            input_scale=self.scale.to(norm_output.device),
+
+        self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
+
+    def forward(self, hidden_states):
+        # avoid having graph input be an arg to a pattern directly
+        z = torch.relu(hidden_states)
+        x = resid = tensor_model_parallel_all_reduce(z)
+        y = self.norm[0](x)
+
+        z2 = self.fp8_linear.apply(
+            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
        )

-        return fp8_linear_result, residual_output
+        x2 = tensor_model_parallel_all_reduce(z2)
+        y2, resid = self.norm[1](x2, resid)

-    def ops_in_model_before(self):
-        ops_to_remove = [torch.ops.vllm.all_reduce.default]  # Always removed by SP
-        # The following are only removed if fusion happens
-        config = get_current_vllm_config()
-        if config.compilation_config.pass_config.enable_fusion:
-            ops_to_remove.append(torch.ops._C.fused_add_rms_norm.default)
-            # Only check for static_scaled_fp8_quant if custom quant_fp8 is enabled
-            if "+quant_fp8" in config.compilation_config.custom_ops:
-                ops_to_remove.append(torch.ops._C.static_scaled_fp8_quant.default)
-        return ops_to_remove
+        z3 = self.fp8_linear.apply(
+            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
+        )
+
+        x3 = tensor_model_parallel_all_reduce(z3)
+        y3, resid = self.norm[2](x3, resid)  # use resid here
+
+        z4 = self.fp8_linear.apply(
+            y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
+        )
+        x4 = tensor_model_parallel_all_reduce(z4)
+        y4, resid = self.norm[3](x4, resid)  # use resid here
+        return y4

    def ops_in_model_after(self):
-        ops_to_add = [
-            torch.ops.vllm.reduce_scatter.default,
+        return [
            torch.ops.vllm.all_gather.default,
+            torch.ops.vllm.reduce_scatter.default,
+        ]
+
+    def ops_in_model_before(self):
+        return [
+            torch.ops.vllm.all_reduce.default,
        ]
-        # The following is only added if fusion happens and custom quant_fp8 is enabled
-        config = get_current_vllm_config()
-        if (
-            config.compilation_config.pass_config.enable_fusion
-            and "+quant_fp8" in config.compilation_config.custom_ops
-        ):
-            ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
-        return ops_to_add

    def ops_in_model(self):
-        config = get_current_vllm_config()
-        if config.compilation_config.pass_config.enable_fusion:
-            # If fusion happens with custom quant_fp8, the fused op is the one
-            # we check for (de)functionalization
+        if self.vllm_config.compilation_config.pass_config.enable_fusion:
            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
-        else:
-            # If no fusion or using native quant, the original ops are checked
+        elif RMSNorm.enabled():
            return [
                torch.ops._C.fused_add_rms_norm.default,
-                # TODO functionalization pass does not handle this yet
-                # torch.ops._C.static_scaled_fp8_quant.default,
            ]
+        elif self.fp8_linear.quant_fp8.enabled():
+            return [
+                torch.ops._C.static_scaled_fp8_quant.default,
+            ]
+        else:
+            return []


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "test_model_cls, custom_ops",
    [
-        (TestModel, ""),
-        (TestQuantModel, "+quant_fp8"),
-        (TestQuantModel, "-quant_fp8"),
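+        # "+op"/"-op" entries toggle the custom vs. native implementation of each op
+        # via compilation_config.custom_ops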
+        # (TestAllReduceRMSNormModel, "+rms_norm"),
+        # (TestAllReduceRMSNormModel, "-rms_norm"),
+        (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
+        (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
+        (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
+        (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
    ],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("enable_fusion", [True, False])
+@pytest.mark.parametrize("dtype", [torch.float16])  # , torch.bfloat16])
+@pytest.mark.parametrize("enable_fusion", [True])  # , False])
+@pytest.mark.parametrize("dynamic", [False])  # , True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_sequence_parallelism_pass(
    test_model_cls: type[torch.nn.Module],
@@ -211,6 +193,7 @@ def test_sequence_parallelism_pass(
    hidden_size: int,
    dtype: torch.dtype,
    enable_fusion: bool,
+    dynamic: bool,
):
    num_processes = 2

@@ -228,6 +211,7 @@ def run_torch_spawn(fn, nprocs):
                hidden_size,
                dtype,
                enable_fusion,
+                dynamic,
            ),
            nprocs=nprocs,
        )
@@ -245,6 +229,7 @@ def sequence_parallelism_pass_on_test_model(
    hidden_size: int,
    dtype: torch.dtype,
    enable_fusion: bool,
+    dynamic: bool,
):
    current_platform.seed_everything(0)

@@ -295,7 +280,6 @@ def sequence_parallelism_pass_on_test_model(
    with set_current_vllm_config(vllm_config):
        noop_pass = NoOpEliminationPass(vllm_config)
        sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
-        func_pass = FixFunctionalizationPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)
        assert (
            sequence_parallelism_pass.compilation_config.splitting_ops
@@ -316,38 +300,28 @@ def sequence_parallelism_pass_on_test_model(

        passes_for_backend.append(cleanup_pass)

-        backend_no_func = TestBackend(*passes_for_backend)
-        backend_func = TestBackend(*passes_for_backend, func_pass)
+        backend = TestBackend(*passes_for_backend)

-        model = test_model_cls(hidden_size, hidden_size * 2)
+        model = test_model_cls(hidden_size)

        hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
-        residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)

-        compiled_model_no_func = torch.compile(model, backend=backend_no_func)
-        compiled_model_no_func(hidden_states, residual)
-        compiled_model_func = torch.compile(model, backend=backend_func)
-        compiled_model_func(hidden_states, residual)
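+        # optionally mark the batch dimension dynamic so the pass is also exercised
+        # under dynamic-shape compilation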
+        if dynamic:
+            torch._dynamo.mark_dynamic(hidden_states, 0)
+
+        compiled_model = torch.compile(model, backend=backend)
+        compiled_model(hidden_states)

-        assert sequence_parallelism_pass.matched_count == 1
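+        # one match per (all_reduce -> rms_norm) pair in the test model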
+        assert sequence_parallelism_pass.matched_count == 4

        # In pre-nodes, all reduce should be there,
        # reduce scatter and all gather should not
-        backend_no_func.check_before_ops(model.ops_in_model_before())
+        backend.check_before_ops(model.ops_in_model_before())

        # In post-nodes, reduce scatter and all gather should be there,
        # all reduce should not
-        backend_no_func.check_after_ops(model.ops_in_model_after())
+        backend.check_after_ops(model.ops_in_model_after())

-        # check if the functionalization pass is applied
+        print(backend.graph_post_pass)
        for op in model.ops_in_model():
-            find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
-            assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
-
-        # make sure the ops were all de-functionalized
-        found = dict()
-        for node in backend_func.graph_post_pass.nodes:
-            for op in model.ops_in_model():
-                if is_func(node, op):
-                    found[op] = True
-        assert all(found[op] for op in model.ops_in_model())
+            find_auto_fn(backend.graph_post_pass.nodes, op)