With these tests, we can at least say that MTP will not break the
correctness of the target model's outputs.
3838""" 
39+ import  os 
3940
4041import  pytest 
4142
4243from  .conftest  import  run_equality_correctness_test 
4344
# NOTE: both the main model and the MTP head are bfloat16.
FLOAT_MODEL = "wemaster/deepseek_mtp_main_random_bf16"

# NOTE: the main model is w8a8-quantized; the MTP head stays bfloat16.
QUANT_MODEL = "wemaster/deepseek_mtp_main_random_w8a8_part"

# TODO: once msmodelslim can quantize both the main and the MTP model,
# this UT should switch to fully-w8a8 weights.

# Maximum number of speculative tokens; this corresponds to
# num_nextn_predict_layers in the speculator model's config.json.
MAX_SPEC_TOKENS = 1

# Precision used for all test models below.
PRECISION = "bfloat16"
# NOTE(review): import-time side effect — forces ModelScope as the model hub
# for every test in this module.
os.environ["VLLM_USE_MODELSCOPE"] = "True"
5562
63+ @pytest .mark .skipif (os .getenv ("VLLM_USE_V1" ) ==  "1" , 
64+                     reason = "mtp is not supported on v1" ) 
5665@pytest .mark .parametrize ( 
5766    "common_llm_kwargs" , 
5867    [{ 
6675        "dtype" : PRECISION , 
6776
6877        # Main model  
69-         "model_name" : MAIN_MODEL , 
78+         "model_name" : FLOAT_MODEL , 
7079
7180        # GPU memory utilization  
7281        "gpu_memory_utilization" : 0.85  
@@ -97,6 +106,7 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
97106                                  batch_size , output_len , seed )
98107
99108
109+ @pytest .mark .skipif (True , reason = "quant model is not ready." ) 
100110@pytest .mark .parametrize ( 
101111    "common_llm_kwargs" , 
102112    [{ 
@@ -110,7 +120,53 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
110120        "dtype" : PRECISION , 
111121
112122        # Main model  
113-         "model_name" : MAIN_MODEL , 
123+         "model_name" : QUANT_MODEL , 
124+ 
125+         # GPU memory utilization  
126+         "gpu_memory_utilization" : 0.85  
127+     }]) 
128+ @pytest .mark .parametrize ("per_test_common_llm_kwargs" , [{}]) 
129+ @pytest .mark .parametrize ("baseline_llm_kwargs" , [{}]) 
130+ @pytest .mark .parametrize ("test_llm_kwargs" , [ 
131+     { 
132+         "speculative_config" : { 
133+             "num_speculative_tokens" : MAX_SPEC_TOKENS , 
134+         }, 
135+     }, 
136+ ]) 
137+ @pytest .mark .parametrize ("output_len" , [ 
138+     128 , 
139+ ]) 
140+ @pytest .mark .parametrize ("batch_size" , [1 , 32 ]) 
141+ @pytest .mark .parametrize ("seed" , [1 ]) 
142+ def  test_mtp_e2e_quant_greedy_correctness (vllm_runner , common_llm_kwargs ,
143+                                           per_test_common_llm_kwargs ,
144+                                           baseline_llm_kwargs , test_llm_kwargs ,
145+                                           batch_size : int , output_len : int ,
146+                                           seed : int ):
147+ 
148+     run_equality_correctness_test (vllm_runner , common_llm_kwargs ,
149+                                   per_test_common_llm_kwargs ,
150+                                   baseline_llm_kwargs , test_llm_kwargs ,
151+                                   batch_size , output_len , seed )
152+ 
153+ 
154+ @pytest .mark .skipif (os .getenv ("VLLM_USE_V1" ) ==  "1" , 
155+                     reason = "mtp is not supported on v1" ) 
156+ @pytest .mark .parametrize ( 
157+     "common_llm_kwargs" , 
158+     [{ 
159+         # Skip cuda graph recording for fast test.  
160+         "enforce_eager" : True , 
161+ 
162+         # Print spec metrics.  
163+         "disable_log_stats" : False , 
164+ 
165+         # Precision  
166+         "dtype" : PRECISION , 
167+ 
168+         # Main model  
169+         "model_name" : FLOAT_MODEL , 
114170
115171        # GPU memory utilization  
116172        "gpu_memory_utilization" : 0.85  
@@ -158,15 +214,13 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
158214        ["disable_logprobs" ])
159215
160216
161- @pytest .mark .skipif ( 
162-     True , 
163-     reason =  
164-     "Open it when vllm-ascend support graph mode and support enforce_eager status is False to run model in graph mode"  
165- ) 
217+ @pytest .mark .skipif (True , reason = "torchair ut can not clean mem." ) 
166218@pytest .mark .parametrize ( 
167219    "common_llm_kwargs" , 
168220    [{ 
169-         "enforce_eager" : False , 
221+         "additional_config" : { 
222+             'enable_graph_mode' : True , 
223+         }, 
170224
171225        # Print spec metrics.  
172226        "disable_log_stats" : False , 
@@ -175,7 +229,7 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
175229        "dtype" : PRECISION , 
176230
177231        # Main model  
178-         "model_name" : MAIN_MODEL , 
232+         "model_name" : FLOAT_MODEL , 
179233        "gpu_memory_utilization" : 0.85  
180234    }]) 
181235@pytest .mark .parametrize ("per_test_common_llm_kwargs" , [{}]) 
@@ -192,20 +246,64 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
192246]) 
193247@pytest .mark .parametrize ("batch_size" , [1 , 32 ]) 
194248@pytest .mark .parametrize ("seed" , [1 ]) 
def test_mtp_e2e_greedy_correctness_torchair_graph(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Greedy outputs with torchair graph mode enabled must match the
    non-speculative baseline across batch sizes (bfloat16 weights).
    """
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
    )
259+ 
260+ 
@pytest.mark.skipif(True, reason="quant model is not ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Run the model through the torchair graph path.
        "additional_config": {
            "enable_graph_mode": True,
        },

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model (w8a8 main weights, bfloat16 MTP head).
        "model_name": QUANT_MODEL,
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [128])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_quant_greedy_correctness_torchair_graph(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Greedy outputs with torchair graph mode enabled must match the
    non-speculative baseline across batch sizes (quantized weights).
    """
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
    )
207303
208304
305+ @pytest .mark .skipif (os .getenv ("VLLM_USE_V1" ) ==  "1" , 
306+                     reason = "mtp is not supported on v1" ) 
209307@pytest .mark .parametrize ( 
210308    "common_llm_kwargs" , 
211309    [{ 
@@ -221,7 +319,7 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
221319        "dtype" : PRECISION , 
222320
223321        # Main model  
224-         "model_name" : MAIN_MODEL , 
322+         "model_name" : FLOAT_MODEL , 
225323
226324        # GPU memory utilization  
227325        "gpu_memory_utilization" : 0.9  
@@ -256,6 +354,8 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
256354                                  batch_size , output_len , seed )
257355
258356
357+ @pytest .mark .skipif (os .getenv ("VLLM_USE_V1" ) ==  "1" , 
358+                     reason = "mtp is not supported on v1" ) 
259359@pytest .mark .parametrize ( 
260360    "common_llm_kwargs" , 
261361    [{ 
@@ -266,7 +366,7 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
266366        "dtype" : PRECISION , 
267367
268368        # Main model  
269-         "model_name" : MAIN_MODEL , 
369+         "model_name" : FLOAT_MODEL , 
270370
271371        # GPU memory utilization  
272372        "gpu_memory_utilization" : 0.9  
@@ -305,6 +405,8 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
305405                                  batch_size , output_len , seed )
306406
307407
408+ @pytest .mark .skipif (os .getenv ("VLLM_USE_V1" ) ==  "1" , 
409+                     reason = "mtp is not supported on v1" ) 
308410@pytest .mark .parametrize ( 
309411    "common_llm_kwargs" , 
310412    [{ 
@@ -315,7 +417,7 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
315417        "dtype" : PRECISION , 
316418
317419        # Main model  
318-         "model_name" : MAIN_MODEL , 
420+         "model_name" : FLOAT_MODEL , 
319421
320422        # GPU memory utilization  
321423        "gpu_memory_utilization" : 0.9  
0 commit comments