1919
import gc
import multiprocessing
import sys
import traceback
from multiprocessing import Queue
from queue import Empty

import lm_eval
import pytest
import torch
2728
# Pre-trained model paths on Hugging Face that this accuracy test evaluates.
# Every per-model table below is keyed by these exact strings.
MODEL_NAME = ["Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct"]
# Benchmark configuration mapping models to evaluation tasks:
# - Text model: GSM8K (grade school math reasoning)
# - Vision-language model: MMMU Art & Design validation (multimodal understanding)
TASK = {
    "Qwen/Qwen2.5-0.5B-Instruct": "gsm8k",
    "Qwen/Qwen2.5-VL-3B-Instruct": "mmmu_val_art_and_design",
}
# Key of the metric to read from each task's entry in the evaluation results.
FILTER = {
    "Qwen/Qwen2.5-0.5B-Instruct": "exact_match,strict-match",
    "Qwen/Qwen2.5-VL-3B-Instruct": "acc,none",
}
# Accepted deviation from the expected accuracy. Despite the name, the test
# applies it as an absolute margin (expected ± 0.03), not a relative one.
RTOL = 0.03
# Baseline accuracy after VLLM optimization.
EXPECTED_VALUE = {
    "Qwen/Qwen2.5-0.5B-Instruct": 0.316,
    "Qwen/Qwen2.5-VL-3B-Instruct": 0.541,
}
# Maximum context length configuration for each model.
MAX_MODEL_LEN = {
    "Qwen/Qwen2.5-0.5B-Instruct": 4096,
    "Qwen/Qwen2.5-VL-3B-Instruct": 8192,
}
# lm_eval backend name for each model, distinguishing the text-only model
# from the vision-language model.
MODEL_TYPE = {
    "Qwen/Qwen2.5-0.5B-Instruct": "vllm",
    "Qwen/Qwen2.5-VL-3B-Instruct": "vllm-vlm",
}
# Whether to wrap prompts in a chat-style template.
APPLY_CHAT_TEMPLATE = {
    "Qwen/Qwen2.5-0.5B-Instruct": False,
    "Qwen/Qwen2.5-VL-3B-Instruct": True,
}
# Whether few-shot examples are presented as multi-turn dialogues.
FEWSHOT_AS_MULTITURN = {
    "Qwen/Qwen2.5-0.5B-Instruct": False,
    "Qwen/Qwen2.5-VL-3B-Instruct": True,
}
3870
3971
def run_test(queue, model, max_model_len, model_type):
    """Evaluate one model with lm_eval and report the metric through ``queue``.

    Intended to run as the target of a ``multiprocessing.Process``.

    Args:
        queue: Queue used to hand the outcome back to the parent process.
            On success it receives the metric value selected by
            ``TASK[model]`` / ``FILTER[model]``; on failure it receives a
            ``RuntimeError`` whose message is the formatted traceback.
        model: Hugging Face model path; also the key into the per-model
            configuration tables.
        max_model_len: Value forwarded as ``max_model_len`` in the model args.
        model_type: lm_eval backend name (``"vllm"`` or ``"vllm-vlm"``).

    Raises:
        SystemExit: With status 1 after any evaluation error has been
            reported, so the worker's exit code reflects the failure.
    """
    try:
        model_args = (f"pretrained={model},max_model_len={max_model_len},"
                      "dtype=auto")
        if model_type == "vllm-vlm":
            # The vision-language backend additionally gets an image limit.
            model_args += ",max_images=2"
        results = lm_eval.simple_evaluate(
            model=model_type,
            model_args=model_args,
            tasks=TASK[model],
            batch_size="auto",
            apply_chat_template=APPLY_CHAT_TEMPLATE[model],
            fewshot_as_multiturn=FEWSHOT_AS_MULTITURN[model],
        )
        result = results["results"][TASK[model]][FILTER[model]]
        print("result:", result)
        queue.put(result)
    except Exception:
        # Send the formatted traceback rather than the exception object:
        # pickling drops the traceback, and an unpicklable exception would
        # never reach the parent at all, leaving it waiting on the queue.
        queue.put(RuntimeError(traceback.format_exc()))
        sys.exit(1)
    finally:
        # Runs on both the success and the failure path (including the
        # SystemExit raised above).
        gc.collect()
        torch.npu.empty_cache()
5697
5798
# Seconds to block on the result queue between worker liveness checks.
_RESULT_POLL_INTERVAL = 5.0


def _wait_for_result(process, result_queue):
    """Return the worker's queue item, or raise if it died without one.

    The queue is drained *before* the process is joined: joining a process
    that still has unflushed queue data can deadlock. Polling with a timeout
    keeps the test from hanging forever when the worker is killed before it
    can report anything.

    Raises:
        RuntimeError: If the worker exited and the queue stayed empty.
    """
    while process.is_alive():
        try:
            return result_queue.get(timeout=_RESULT_POLL_INTERVAL)
        except Empty:
            continue
    # The worker has exited; allow one last bounded read in case its item
    # arrived between the final poll and the exit.
    try:
        return result_queue.get(timeout=_RESULT_POLL_INTERVAL)
    except Empty:
        raise RuntimeError(
            f"worker exited with code {process.exitcode} "
            "without reporting a result") from None


def test_lm_eval_accuracy(monkeypatch: pytest.MonkeyPatch):
    """Check each model's measured accuracy against its expected baseline.

    Every model in ``MODEL_NAME`` is evaluated by ``run_test`` in its own
    subprocess. The reported metric must lie strictly within
    ``EXPECTED_VALUE[model] ± RTOL`` (an absolute margin). A failure reported
    by the worker fails the test with the worker's error instead of being
    compared as if it were a number.
    """
    for model in MODEL_NAME:
        with monkeypatch.context():
            result_queue: Queue = multiprocessing.Queue()
            p = multiprocessing.Process(target=run_test,
                                        args=(result_queue, model,
                                              MAX_MODEL_LEN[model],
                                              MODEL_TYPE[model]))
            p.start()
            result = _wait_for_result(p, result_queue)
            p.join()
            print(result)
            if isinstance(result, BaseException):
                pytest.fail(
                    f"Evaluation of {model} failed in the worker process: "
                    f"{result}")
            expected = EXPECTED_VALUE[model]
            assert (expected - RTOL < result < expected + RTOL), \
                f"Expected: {expected}±{RTOL} | Measured: {result}"