 import torch
 
 import vllm.envs as envs
-from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.compilation.fusion import RMSNormQuantFusionPass
-from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
+from vllm.compilation.fx_utils import find_auto_fn
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.compilation.sequence_parallelism import SequenceParallelismPass
     initialize_model_parallel,
 )
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
 from vllm.platforms import current_platform
 from vllm.utils import update_environment_variables
 ]
 
 
-class TestModel(torch.nn.Module):
-    def __init__(self, hidden_size=16, intermediate_size=32):
+class TestAllReduceRMSNormModel(torch.nn.Module):
+    def __init__(self, hidden_size=16, eps=1e-6):
         super().__init__()
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.gate_proj = torch.nn.Parameter(
-            torch.empty((intermediate_size, hidden_size))
-        )
-        self.norm = RMSNorm(intermediate_size, 1e-05)
-        # Initialize weights
-        torch.nn.init.normal_(self.gate_proj, std=0.02)
+        self.eps = eps
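+        # four norm layers and three weights give the four all_reduce +
+        # rms_norm pairs the sequence-parallelism pass is expected to match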
+        self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
+        self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
 
-    def forward(self, hidden_states, residual):
-        """
-        Forward pass implementing the operations in the FX graph
+    def forward(self, x):
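+        # avoid having graph input be an arg to a pattern directly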
+        z = torch.relu(x)
+        x = resid = tensor_model_parallel_all_reduce(z)
+        y = self.norm[0](x)
 
-        Args:
-            hidden_states: Input tensor
-            residual: Residual tensor from previous layer
+        z2 = torch.mm(y, self.w[0])
+        x2 = tensor_model_parallel_all_reduce(z2)
 
-        Returns:
-            Tuple containing the output tensor
-        """
-        # Reshape input
-        view = hidden_states.reshape(-1, self.hidden_size)
+        y2, resid = self.norm[1](x2, resid)
 
-        # matrix multiplication
-        permute = self.gate_proj.permute(1, 0)
-        mm = torch.mm(view, permute)
+        z3 = torch.mm(y2, self.w[1])
+        x3 = tensor_model_parallel_all_reduce(z3)
 
-        # Tensor parallel all-reduce
-        all_reduce = tensor_model_parallel_all_reduce(mm)
+        y3, resid = self.norm[2](x3, resid)
 
-        # layer normalization
-        norm_output, residual_output = self.norm(all_reduce, residual)
+        z4 = torch.mm(y3, self.w[2])
+        x4 = tensor_model_parallel_all_reduce(z4)
 
-        return norm_output, residual_output
+        y4, resid = self.norm[3](x4, resid)
+        return y4
 
     def ops_in_model_before(self):
         return [torch.ops.vllm.all_reduce.default]
 
     def ops_in_model_after(self):
         return [
-            torch.ops.vllm.reduce_scatter.default,
             torch.ops.vllm.all_gather.default,
+            torch.ops.vllm.reduce_scatter.default,
         ]
 
     def ops_in_model(self):
-        return [torch.ops._C.fused_add_rms_norm.default]
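+        # the custom rms_norm ops only appear in the graph when the rms_norm
+        # custom op is enabled; the native implementation leaves none to check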
+        if RMSNorm.enabled():
+            return [
+                torch.ops._C.rms_norm.default,
+                torch.ops._C.fused_add_rms_norm.default,
+            ]
+        else:
+            return []
 
 
-class TestQuantModel(torch.nn.Module):
-    def __init__(self, hidden_size=16, intermediate_size=32):
+class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
+    def __init__(self, hidden_size=16, eps=1e-6):
         super().__init__()
+        self.vllm_config = get_current_vllm_config()
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.gate_proj = torch.nn.Parameter(
-            torch.empty((intermediate_size, hidden_size)), requires_grad=False
+        self.eps = eps
+        self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
+        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
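+        # transposed to the column-major layout expected by torch._scaled_mm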
+        self.w = [
+            torch.rand(hidden_size, hidden_size)
+            .to(dtype=current_platform.fp8_dtype())
+            .t()
+            for _ in range(3)
+        ]
+
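+        # static per-tensor activation quantization; input scales are constants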
+        self.fp8_linear = Fp8LinearOp(
+            act_quant_static=True,
+            act_quant_group_shape=GroupShape.PER_TENSOR,
         )
-        self.norm = RMSNorm(intermediate_size, 1e-05)
-        # Initialize weights
-        torch.nn.init.normal_(self.gate_proj, std=0.02)
-
-        self.fp8_linear = Fp8LinearOp(act_quant_static=True)
-
-        self.scale = torch.rand(1, dtype=torch.float32)
-        # Create a weight that is compatible with torch._scaled_mm,
-        # which expects a column-major layout.
-        self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
-        self.wscale = torch.rand(1, dtype=torch.float32)
-
-    def forward(self, hidden_states, residual):
-        """
-        Forward pass implementing the operations in the FX graph
-
-        Args:
-            hidden_states: Input tensor
-            residual: Residual tensor from previous layer
-
-        Returns:
-            Tuple containing the output tensor
-        """
-        # Reshape input
-        view = hidden_states.reshape(-1, self.hidden_size)
-
-        # matrix multiplication
-        permute = self.gate_proj.permute(1, 0)
-        mm = torch.mm(view, permute)
-
-        # Tensor parallel all-reduce
-        all_reduce = tensor_model_parallel_all_reduce(mm)
-
-        # layer normalization
-        norm_output, residual_output = self.norm(all_reduce, residual)
-
-        # scaled_mm with static input quantization
-        fp8_linear_result = self.fp8_linear.apply(
-            norm_output,
-            self.w,
-            self.wscale,
-            input_scale=self.scale.to(norm_output.device),
+
+        self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
+
+    def forward(self, hidden_states):
+        # avoid having graph input be an arg to a pattern directly
+        z = torch.relu(hidden_states)
+        x = resid = tensor_model_parallel_all_reduce(z)
+        y = self.norm[0](x)
+
+        z2 = self.fp8_linear.apply(
+            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
         )
 
-        return fp8_linear_result, residual_output
+        x2 = tensor_model_parallel_all_reduce(z2)
+        y2, resid = self.norm[1](x2, resid)
 
-    def ops_in_model_before(self):
-        ops_to_remove = [torch.ops.vllm.all_reduce.default]  # Always removed by SP
-        # The following are only removed if fusion happens
-        config = get_current_vllm_config()
-        if config.compilation_config.pass_config.enable_fusion:
-            ops_to_remove.append(torch.ops._C.fused_add_rms_norm.default)
-        # Only check for static_scaled_fp8_quant if custom quant_fp8 is enabled
-        if "+quant_fp8" in config.compilation_config.custom_ops:
-            ops_to_remove.append(torch.ops._C.static_scaled_fp8_quant.default)
-        return ops_to_remove
+        z3 = self.fp8_linear.apply(
+            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
+        )
+
+        x3 = tensor_model_parallel_all_reduce(z3)
+        y3, resid = self.norm[2](x3, resid)  # use resid here
+
+        z4 = self.fp8_linear.apply(
+            y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
+        )
+        x4 = tensor_model_parallel_all_reduce(z4)
+        y4, resid = self.norm[3](x4, resid)  # use resid here
+        return y4
 
     def ops_in_model_after(self):
-        ops_to_add = [
-            torch.ops.vllm.reduce_scatter.default,
+        return [
             torch.ops.vllm.all_gather.default,
+            torch.ops.vllm.reduce_scatter.default,
+        ]
+
+    def ops_in_model_before(self):
+        return [
+            torch.ops.vllm.all_reduce.default,
         ]
-        # The following is only added if fusion happens and custom quant_fp8 is enabled
-        config = get_current_vllm_config()
-        if (
-            config.compilation_config.pass_config.enable_fusion
-            and "+quant_fp8" in config.compilation_config.custom_ops
-        ):
-            ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
-        return ops_to_add
 
     def ops_in_model(self):
-        config = get_current_vllm_config()
-        if config.compilation_config.pass_config.enable_fusion:
-            # If fusion happens with custom quant_fp8, the fused op is the one
-            # we check for (de)functionalization
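+        # which ops remain depends on fusion and on which custom ops are enabled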
+        if self.vllm_config.compilation_config.pass_config.enable_fusion:
             return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
-        else:
-            # If no fusion or using native quant, the original ops are checked
+        elif RMSNorm.enabled():
             return [
                 torch.ops._C.fused_add_rms_norm.default,
-                # TODO functionalization pass does not handle this yet
-                # torch.ops._C.static_scaled_fp8_quant.default,
             ]
+        elif self.fp8_linear.quant_fp8.enabled():
+            return [
+                torch.ops._C.static_scaled_fp8_quant.default,
+            ]
+        else:
+            return []
 
 
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize(
     "test_model_cls, custom_ops",
     [
-        (TestModel, ""),
-        (TestQuantModel, "+quant_fp8"),
-        (TestQuantModel, "-quant_fp8"),
+        (TestAllReduceRMSNormModel, "+rms_norm"),
+        (TestAllReduceRMSNormModel, "-rms_norm"),
+        (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
+        (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
+        (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
+        (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
     ],
 )
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seq_len", [16])
 @pytest.mark.parametrize("hidden_size", [16])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("enable_fusion", [True, False])
+@pytest.mark.parametrize("dynamic", [False, True])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 def test_sequence_parallelism_pass(
     test_model_cls: type[torch.nn.Module],
@@ -211,6 +193,7 @@ def test_sequence_parallelism_pass(
     hidden_size: int,
     dtype: torch.dtype,
     enable_fusion: bool,
+    dynamic: bool,
 ):
     num_processes = 2
 
@@ -228,6 +211,7 @@ def run_torch_spawn(fn, nprocs):
             hidden_size,
             dtype,
             enable_fusion,
+            dynamic,
         ),
         nprocs=nprocs,
     )
@@ -245,6 +229,7 @@ def sequence_parallelism_pass_on_test_model(
     hidden_size: int,
     dtype: torch.dtype,
     enable_fusion: bool,
+    dynamic: bool,
 ):
     current_platform.seed_everything(0)
 
@@ -295,7 +280,6 @@ def sequence_parallelism_pass_on_test_model(
     with set_current_vllm_config(vllm_config):
         noop_pass = NoOpEliminationPass(vllm_config)
         sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
-        func_pass = FixFunctionalizationPass(vllm_config)
         cleanup_pass = PostCleanupPass(vllm_config)
         assert (
             sequence_parallelism_pass.compilation_config.splitting_ops
@@ -316,38 +300,41 @@ def sequence_parallelism_pass_on_test_model(
 
     passes_for_backend.append(cleanup_pass)
 
-    backend_no_func = TestBackend(*passes_for_backend)
-    backend_func = TestBackend(*passes_for_backend, func_pass)
+    backend = TestBackend(*passes_for_backend)
 
-    model = test_model_cls(hidden_size, hidden_size * 2)
+    model = test_model_cls(hidden_size)
 
     hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
-    residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
 
-    compiled_model_no_func = torch.compile(model, backend=backend_no_func)
-    compiled_model_no_func(hidden_states, residual)
-    compiled_model_func = torch.compile(model, backend=backend_func)
-    compiled_model_func(hidden_states, residual)
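+    # mark the batch dim dynamic to exercise the pass with symbolic shapes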
+    if dynamic:
+        torch._dynamo.mark_dynamic(hidden_states, 0)
 
-    assert sequence_parallelism_pass.matched_count == 1
+    compiled_model = torch.compile(model, backend=backend)
+    compiled_model(hidden_states)
+
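+    # one match per all_reduce + rms_norm pair in the test models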
+    assert sequence_parallelism_pass.matched_count == 4
 
     # In pre-nodes, all reduce should be there,
     # reduce scatter and all gather should not
-    backend_no_func.check_before_ops(model.ops_in_model_before())
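+    # count occurrences by hand: each op should appear once per matched pattern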
+    pre_ops = [
+        node.target
+        for node in backend.graph_pre_pass.nodes
+        if node.op == "call_function"
+    ]
+    for op in model.ops_in_model_before():
+        num_op = len([pre_op for pre_op in pre_ops if pre_op == op])
+        assert num_op == 4
 
     # In post-nodes, reduce scatter and all gather should be there,
     # all reduce should not
-    backend_no_func.check_after_ops(model.ops_in_model_after())
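+    # same manual count for the collectives inserted by the pass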
+    post_ops = [
+        node.target
+        for node in backend.graph_post_pass.nodes
+        if node.op == "call_function"
+    ]
+    for op in model.ops_in_model_after():
+        num_op = len([post_op for post_op in post_ops if post_op == op])
+        assert num_op == 4
 
-    # check if the functionalization pass is applied
     for op in model.ops_in_model():
-        find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
-        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
-
-    # make sure the ops were all de-functionalized
-    found = dict()
-    for node in backend_func.graph_post_pass.nodes:
-        for op in model.ops_in_model():
-            if is_func(node, op):
-                found[op] = True
-    assert all(found[op] for op in model.ops_in_model())
+        find_auto_fn(backend.graph_post_pass.nodes, op)