1616# This file is a part of the vllm-ascend project.
1717# Adapted from vllm-project/blob/main/tests/entrypoints/llm/test_accuracy.py
1818#
19-
2019import gc
2120import multiprocessing
21+ import os
22+ import signal
23+ import subprocess
2224import sys
25+ import time
2326from multiprocessing import Queue
2427
2528import lm_eval
2629import pytest
30+ import requests
2731import torch
2832
# Address of the local `vllm serve` instance that the DP test talks to.
SERVER_HOST = "127.0.0.1"
SERVER_PORT = 8000
HEALTH_URL = f"http://{SERVER_HOST}:{SERVER_PORT}/health"
COMPLETIONS_URL = f"http://{SERVER_HOST}:{SERVER_PORT}/v1/completions"

# pre-trained model path on Hugging Face.
# Qwen/Qwen2.5-0.5B-Instruct: accuracy test for unimodal model.
# Qwen/Qwen2.5-VL-3B-Instruct: accuracy test for multimodal model.
# Qwen/Qwen3-30B-A3B: accuracy test for EP.
# deepseek-ai/DeepSeek-V2-Lite: accuracy test for TP.
MODEL_NAME = [
    "Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct",
    "Qwen/Qwen3-30B-A3B", "deepseek-ai/DeepSeek-V2-Lite"
]
# Qwen/Qwen2.5-0.5B-Instruct: accuracy test for DP.
MODEL_NAME_DP = ["Qwen/Qwen2.5-0.5B-Instruct"]

# Benchmark configuration mapping models to evaluation tasks:
# - Text models: GSM8K (grade school math reasoning)
# - Vision-language model: MMMU Art & Design validation (multimodal understanding)
TASK = {
    "Qwen/Qwen2.5-0.5B-Instruct": "gsm8k",
    "Qwen/Qwen2.5-VL-3B-Instruct": "mmmu_val_art_and_design",
    "Qwen/Qwen3-30B-A3B": "gsm8k",
    "deepseek-ai/DeepSeek-V2-Lite": "gsm8k"
}
# Answer validation requiring format consistency.
FILTER = {
    "Qwen/Qwen2.5-0.5B-Instruct": "exact_match,strict-match",
    "Qwen/Qwen2.5-VL-3B-Instruct": "acc,none",
    "Qwen/Qwen3-30B-A3B": "exact_match,strict-match",
    "deepseek-ai/DeepSeek-V2-Lite": "exact_match,strict-match"
}
# Tolerance for numerical accuracy. NOTE: the accuracy test applies this as an
# absolute band (EXPECTED_VALUE ± RTOL), not as a relative tolerance.
RTOL = 0.03
# Baseline accuracy after VLLM optimization.
EXPECTED_VALUE = {
    "Qwen/Qwen2.5-0.5B-Instruct": 0.316,
    "Qwen/Qwen2.5-VL-3B-Instruct": 0.541,
    "Qwen/Qwen3-30B-A3B": 0.888,
    "deepseek-ai/DeepSeek-V2-Lite": 0.376
}
# Maximum context length configuration for each model.
MAX_MODEL_LEN = {
    "Qwen/Qwen2.5-0.5B-Instruct": 4096,
    "Qwen/Qwen2.5-VL-3B-Instruct": 8192,
    "Qwen/Qwen3-30B-A3B": 4096,
    "deepseek-ai/DeepSeek-V2-Lite": 4096
}
# Model types distinguishing text-only and vision-language models; the value is
# passed as the `model` argument of lm_eval.simple_evaluate.
MODEL_TYPE = {
    "Qwen/Qwen2.5-0.5B-Instruct": "vllm",
    "Qwen/Qwen2.5-VL-3B-Instruct": "vllm-vlm",
    "Qwen/Qwen3-30B-A3B": "vllm",
    "deepseek-ai/DeepSeek-V2-Lite": "vllm"
}
# wrap prompts in a chat-style template (keyed by model name).
APPLY_CHAT_TEMPLATE = {
    "Qwen/Qwen2.5-0.5B-Instruct": False,
    "Qwen/Qwen2.5-VL-3B-Instruct": True,
    "Qwen/Qwen3-30B-A3B": False,
    "deepseek-ai/DeepSeek-V2-Lite": False
}
# Few-shot examples handling as multi-turn dialogues (keyed by model name).
FEWSHOT_AS_MULTITURN = {
    "Qwen/Qwen2.5-0.5B-Instruct": False,
    "Qwen/Qwen2.5-VL-3B-Instruct": True,
    "Qwen/Qwen3-30B-A3B": False,
    "deepseek-ai/DeepSeek-V2-Lite": False
}
# MORE_ARGS: extra lm_eval `model_args` appended per model (None = nothing to
# append).
MORE_ARGS = {
    "Qwen/Qwen2.5-0.5B-Instruct":
    None,
    "Qwen/Qwen2.5-VL-3B-Instruct":
    None,
    "Qwen/Qwen3-30B-A3B":
    "tensor_parallel_size=4,enable_expert_parallel=True,enforce_eager=True",
    "deepseek-ai/DeepSeek-V2-Lite":
    "tensor_parallel_size=4,trust_remote_code=True,enforce_eager=True"
}

# Worker processes are always started with "spawn"; force=True overrides any
# start method that was selected earlier in this interpreter.
multiprocessing.set_start_method("spawn", force=True)
64116
65117
66- def run_test (queue , model , max_model_len , model_type ):
def get_available_npu_count():
    """Return the number of NPU devices reported by ``torch.npu``.

    Returns 0 when the running ``torch`` build exposes no ``npu`` namespace,
    so callers that gate on the device count skip instead of failing with
    ``AttributeError``.
    """
    npu_backend = getattr(torch, "npu", None)
    if npu_backend is None:
        return 0
    return npu_backend.device_count()
120+
121+
122+ def run_test (queue , model , max_model_len , model_type , more_args ):
67123 try :
68124 if model_type == "vllm-vlm" :
69125 model_args = (f"pretrained={ model } ,max_model_len={ max_model_len } ,"
70126 "dtype=auto,max_images=2" )
71127 else :
72128 model_args = (f"pretrained={ model } ,max_model_len={ max_model_len } ,"
73129 "dtype=auto" )
130+ if more_args is not None :
131+ model_args = f"{ model_args } ,{ more_args } "
74132 results = lm_eval .simple_evaluate (
75133 model = model_type ,
76134 model_args = model_args ,
77135 tasks = TASK [model ],
78136 batch_size = "auto" ,
79- apply_chat_template = APPLY_CHAT_TEMPLATE [model_type ],
80- fewshot_as_multiturn = FEWSHOT_AS_MULTITURN [model_type ],
137+ apply_chat_template = APPLY_CHAT_TEMPLATE [model ],
138+ fewshot_as_multiturn = FEWSHOT_AS_MULTITURN [model ],
81139 )
82140 result = results ["results" ][TASK [model ]][FILTER [model ]]
83141 print ("result:" , result )
84142 queue .put (result )
85143 except Exception as e :
86- queue .put (e )
144+ error_msg = f"{ type (e ).__name__ } : { str (e )} "
145+ queue .put (error_msg )
87146 sys .exit (1 )
88147 finally :
89148 gc .collect ()
@@ -93,19 +152,89 @@ def run_test(queue, model, max_model_len, model_type):
@pytest.mark.parametrize("model", MODEL_NAME)
@pytest.mark.parametrize("VLLM_USE_V1", ["0", "1"])
def test_lm_eval_accuracy(monkeypatch: pytest.MonkeyPatch, model, VLLM_USE_V1):
    """Run lm_eval for ``model`` in a child process and check its accuracy.

    The measured score must lie strictly inside
    ``EXPECTED_VALUE[model] ± RTOL``. The case is skipped when the model,
    engine version or visible NPU count does not match the configuration the
    model is meant to be evaluated with.

    Args:
        monkeypatch: pytest fixture used to export ``VLLM_USE_V1`` for the
            duration of the evaluation.
        model: Hugging Face model id, one of ``MODEL_NAME``.
        VLLM_USE_V1: engine selector ("0" or "1") exported to the worker.
    """
    # Function-scope import: the module only imports multiprocessing's Queue.
    from queue import Empty

    npu_count = get_available_npu_count()
    if model == "Qwen/Qwen2.5-VL-3B-Instruct" and VLLM_USE_V1 == "1":
        pytest.skip(
            "Qwen2.5-VL-3B-Instruct is not supported when VLLM_USE_V1=1")
    if model in ("Qwen/Qwen2.5-VL-3B-Instruct",
                 "Qwen/Qwen2.5-0.5B-Instruct") and npu_count != 1:
        pytest.skip(
            "test accuracy for Qwen2.5-0.5B-Instruct and Qwen2.5-VL-3B-Instruct when tp != 1"
        )
    # Gate on the parametrized engine version (not the ambient environment) so
    # that each parametrization decides for itself.
    if model in ("Qwen/Qwen3-30B-A3B", "deepseek-ai/DeepSeek-V2-Lite") and (
            VLLM_USE_V1 != "1" or npu_count != 4):
        pytest.skip(
            "test ep accuracy for Qwen/Qwen3-30B-A3B when VLLM_USE_V1=1 and tp=4"
        )
    with monkeypatch.context() as m:
        # Exported before the worker starts; the worker process inherits the
        # parent's environment.
        m.setenv("VLLM_USE_V1", VLLM_USE_V1)
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(target=run_test,
                                    args=(result_queue, model,
                                          MAX_MODEL_LEN[model],
                                          MODEL_TYPE[model], MORE_ARGS[model]))
        p.start()
        # Read the result before joining, and keep checking that the worker is
        # still alive, so a worker that dies without reporting fails the test
        # instead of blocking it forever.
        while True:
            try:
                result = result_queue.get(timeout=5)
                break
            except Empty:
                if p.is_alive():
                    continue
                try:
                    # The worker may have posted just before exiting.
                    result = result_queue.get(timeout=1)
                    break
                except Empty:
                    pytest.fail(
                        f"lm_eval worker for {model} exited with code "
                        f"{p.exitcode} without reporting a result")
        p.join()
        print(result)
        # run_test reports failures as a "<ExceptionType>: <message>" string.
        if not isinstance(result, (int, float)):
            pytest.fail(f"lm_eval worker for {model} failed: {result}")
        assert (EXPECTED_VALUE[model] - RTOL < result < EXPECTED_VALUE[model] + RTOL), \
            f"Expected: {EXPECTED_VALUE[model]}±{RTOL} | Measured: {result}"
182+
183+
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize("model", MODEL_NAME_DP)
def test_lm_eval_accuracy_dp(model, max_tokens):
    """Serve ``model`` with TP=2 x DP=2 and compare one completion to a fixed text.

    Starts ``vllm serve`` as a subprocess, waits up to 300 s for the health
    endpoint to answer 200, posts a single completion request and asserts the
    stripped generated text. The server is always shut down afterwards.

    Skipped unless the model is Qwen2.5-0.5B-Instruct, the environment has
    ``VLLM_USE_V1=1`` and exactly 4 NPUs are visible.

    Args:
        model: Hugging Face model id, one of ``MODEL_NAME_DP``.
        max_tokens: number of tokens requested from the completions endpoint.
    """
    npu_count = get_available_npu_count()
    if (model != "Qwen/Qwen2.5-0.5B-Instruct"
            or os.getenv("VLLM_USE_V1") != "1" or npu_count != 4):
        pytest.skip(
            "test accuracy for DP when model is Qwen2.5-0.5B-Instruct and engine is V1"
        )

    cmd = [
        "vllm", "serve", model, "--tensor_parallel_size", "2",
        "--data_parallel_size", "2",
        # Keep the listening port in sync with HEALTH_URL / COMPLETIONS_URL.
        "--port", str(SERVER_PORT)
    ]
    # The context manager closes the log handle even when the test fails.
    with open("accuracy.log", "a") as log_file:
        server_proc = subprocess.Popen(
            cmd,
            stdout=log_file,
            # Keep startup errors in the log instead of discarding them.
            stderr=subprocess.STDOUT)

        try:
            for _ in range(300):
                # A crashed server will never become healthy: fail right away.
                if server_proc.poll() is not None:
                    pytest.fail(
                        "vLLM serve exited early with code "
                        f"{server_proc.returncode}; see accuracy.log")
                try:
                    r = requests.get(HEALTH_URL, timeout=1)
                    if r.status_code == 200:
                        break
                except requests.exceptions.RequestException:
                    pass
                time.sleep(1)
            else:
                pytest.fail(
                    f"vLLM serve did not become healthy after 300s: {HEALTH_URL}")

            prompt = "bejing is a"
            # NOTE(review): the sampling options are nested under
            # "sampling_params"; the completions API normally takes
            # temperature/top_p/seed as top-level fields, so these may be
            # ignored by the server. Left as-is because the expected text
            # below was recorded with this payload - confirm before changing.
            payload = {
                "prompt": prompt,
                "max_tokens": max_tokens,
                "sampling_params": {
                    "temperature": 0.0,
                    "top_p": 1.0,
                    "seed": 123
                }
            }
            resp = requests.post(COMPLETIONS_URL, json=payload, timeout=30)
            resp.raise_for_status()
            data = resp.json()

            generated = data["choices"][0]["text"].strip()
            expected = "city in north china, it has many famous attractions"
            assert generated == expected, f"Expected `{expected}`, got `{generated}`"

        finally:
            # Ask for a graceful shutdown first; force-kill if it does not
            # exit within 10 s. Skip the signal if the server is already gone.
            if server_proc.poll() is None:
                server_proc.send_signal(signal.SIGINT)
                try:
                    server_proc.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    server_proc.kill()
                    server_proc.wait()
0 commit comments