With those tests, we can at least verify that MTP does not break
the correctness of the target model's outputs.
3838"""
39+ import os
3940
4041import pytest
4142
4243from .conftest import run_equality_correctness_test
4344
# NOTE: both the main model and the MTP module use bfloat16 weights.
FLOAT_MODEL = "wemaster/deepseek_mtp_main_random_bf16"

# NOTE: the main model is quantized (w8a8) while the MTP module is bfloat16.
QUANT_MODEL = "wemaster/deepseek_mtp_main_random_w8a8_part"

# TODO: when msmodelslim can quantize both the main and the MTP model,
# this UT should use fully w8a8 weights.

# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS = 1

# precision shared by every test configuration below
PRECISION = "bfloat16"

# Fetch model weights from ModelScope rather than the Hugging Face Hub.
os.environ["VLLM_USE_MODELSCOPE"] = "True"
5461
5562
63+ @pytest .mark .skipif (os .getenv ("VLLM_USE_V1" ) == "1" ,
64+ reason = "mtp is not supported on v1" )
5665@pytest .mark .parametrize (
5766 "common_llm_kwargs" ,
5867 [{
6675 "dtype" : PRECISION ,
6776
6877 # Main model
69- "model_name" : MAIN_MODEL ,
78+ "model_name" : FLOAT_MODEL ,
7079
7180 # GPU memory utilization
7281 "gpu_memory_utilization" : 0.85
@@ -97,6 +106,7 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
97106 batch_size , output_len , seed )
98107
99108
109+ @pytest .mark .skipif (True , reason = "quant model is not ready." )
100110@pytest .mark .parametrize (
101111 "common_llm_kwargs" ,
102112 [{
@@ -110,7 +120,53 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
110120 "dtype" : PRECISION ,
111121
112122 # Main model
113- "model_name" : MAIN_MODEL ,
123+ "model_name" : QUANT_MODEL ,
124+
125+ # GPU memory utilization
126+ "gpu_memory_utilization" : 0.85
127+ }])
128+ @pytest .mark .parametrize ("per_test_common_llm_kwargs" , [{}])
129+ @pytest .mark .parametrize ("baseline_llm_kwargs" , [{}])
130+ @pytest .mark .parametrize ("test_llm_kwargs" , [
131+ {
132+ "speculative_config" : {
133+ "num_speculative_tokens" : MAX_SPEC_TOKENS ,
134+ },
135+ },
136+ ])
137+ @pytest .mark .parametrize ("output_len" , [
138+ 128 ,
139+ ])
140+ @pytest .mark .parametrize ("batch_size" , [1 , 32 ])
141+ @pytest .mark .parametrize ("seed" , [1 ])
142+ def test_mtp_e2e_quant_greedy_correctness (vllm_runner , common_llm_kwargs ,
143+ per_test_common_llm_kwargs ,
144+ baseline_llm_kwargs , test_llm_kwargs ,
145+ batch_size : int , output_len : int ,
146+ seed : int ):
147+
148+ run_equality_correctness_test (vllm_runner , common_llm_kwargs ,
149+ per_test_common_llm_kwargs ,
150+ baseline_llm_kwargs , test_llm_kwargs ,
151+ batch_size , output_len , seed )
152+
153+
154+ @pytest .mark .skipif (os .getenv ("VLLM_USE_V1" ) == "1" ,
155+ reason = "mtp is not supported on v1" )
156+ @pytest .mark .parametrize (
157+ "common_llm_kwargs" ,
158+ [{
159+ # Skip cuda graph recording for fast test.
160+ "enforce_eager" : True ,
161+
162+ # Print spec metrics.
163+ "disable_log_stats" : False ,
164+
165+ # Precision
166+ "dtype" : PRECISION ,
167+
168+ # Main model
169+ "model_name" : FLOAT_MODEL ,
114170
115171 # GPU memory utilization
116172 "gpu_memory_utilization" : 0.85
@@ -158,15 +214,13 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
158214 ["disable_logprobs" ])
159215
160216
161- @pytest .mark .skipif (
162- True ,
163- reason =
164- "Open it when vllm-ascend support graph mode and support enforce_eager status is False to run model in graph mode"
165- )
217+ @pytest .mark .skipif (True , reason = "torchair ut can not clean mem." )
166218@pytest .mark .parametrize (
167219 "common_llm_kwargs" ,
168220 [{
169- "enforce_eager" : False ,
221+ "additional_config" : {
222+ 'enable_graph_mode' : True ,
223+ },
170224
171225 # Print spec metrics.
172226 "disable_log_stats" : False ,
@@ -175,7 +229,7 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
175229 "dtype" : PRECISION ,
176230
177231 # Main model
178- "model_name" : MAIN_MODEL ,
232+ "model_name" : FLOAT_MODEL ,
179233 "gpu_memory_utilization" : 0.85
180234 }])
181235@pytest .mark .parametrize ("per_test_common_llm_kwargs" , [{}])
@@ -192,20 +246,64 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
192246])
193247@pytest .mark .parametrize ("batch_size" , [1 , 32 ])
194248@pytest .mark .parametrize ("seed" , [1 ])
195- def test_mtp_e2e_greedy_correctness_cuda_graph (vllm_runner , common_llm_kwargs ,
196- per_test_common_llm_kwargs ,
197- baseline_llm_kwargs ,
198- test_llm_kwargs ,
199- batch_size : int ,
200- output_len : int , seed : int ):
201- """Verify greedy equality with cuda graph enabled and different
202- batch sizes."""
249+ def test_mtp_e2e_greedy_correctness_torchair_graph (
250+ vllm_runner , common_llm_kwargs , per_test_common_llm_kwargs ,
251+ baseline_llm_kwargs , test_llm_kwargs , batch_size : int , output_len : int ,
252+ seed : int ):
253+ """Verify greedy equality with torchair graph enabled and different
254+ batch sizes using bfloat16 weights."""
255+ run_equality_correctness_test (vllm_runner , common_llm_kwargs ,
256+ per_test_common_llm_kwargs ,
257+ baseline_llm_kwargs , test_llm_kwargs ,
258+ batch_size , output_len , seed )
259+
260+
@pytest.mark.skipif(True, reason="quant model is not ready.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Run the model through the torchair graph path.
        "additional_config": {
            "enable_graph_mode": True,
        },

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": QUANT_MODEL,
        "gpu_memory_utilization": 0.85
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_quant_greedy_correctness_torchair_graph(
        vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
        seed: int):
    """Verify greedy equality with torchair graph enabled and different
    batch sizes using quant weights."""
    run_equality_correctness_test(vllm_runner, common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs, test_llm_kwargs,
                                  batch_size, output_len, seed)
207303
208304
305+ @pytest .mark .skipif (os .getenv ("VLLM_USE_V1" ) == "1" ,
306+ reason = "mtp is not supported on v1" )
209307@pytest .mark .parametrize (
210308 "common_llm_kwargs" ,
211309 [{
@@ -221,7 +319,7 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
221319 "dtype" : PRECISION ,
222320
223321 # Main model
224- "model_name" : MAIN_MODEL ,
322+ "model_name" : FLOAT_MODEL ,
225323
226324 # GPU memory utilization
227325 "gpu_memory_utilization" : 0.9
@@ -256,6 +354,8 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
256354 batch_size , output_len , seed )
257355
258356
357+ @pytest .mark .skipif (os .getenv ("VLLM_USE_V1" ) == "1" ,
358+ reason = "mtp is not supported on v1" )
259359@pytest .mark .parametrize (
260360 "common_llm_kwargs" ,
261361 [{
@@ -266,7 +366,7 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
266366 "dtype" : PRECISION ,
267367
268368 # Main model
269- "model_name" : MAIN_MODEL ,
369+ "model_name" : FLOAT_MODEL ,
270370
271371 # GPU memory utilization
272372 "gpu_memory_utilization" : 0.9
@@ -305,6 +405,8 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
305405 batch_size , output_len , seed )
306406
307407
408+ @pytest .mark .skipif (os .getenv ("VLLM_USE_V1" ) == "1" ,
409+ reason = "mtp is not supported on v1" )
308410@pytest .mark .parametrize (
309411 "common_llm_kwargs" ,
310412 [{
@@ -315,7 +417,7 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
315417 "dtype" : PRECISION ,
316418
317419 # Main model
318- "model_name" : MAIN_MODEL ,
420+ "model_name" : FLOAT_MODEL ,
319421
320422 # GPU memory utilization
321423 "gpu_memory_utilization" : 0.9
0 commit comments