diff --git a/tests/long_term/test_accuracy.py b/tests/long_term/test_accuracy.py index c6eefa4e05..a9d9619100 100644 --- a/tests/long_term/test_accuracy.py +++ b/tests/long_term/test_accuracy.py @@ -19,6 +19,7 @@ import gc import multiprocessing +import sys from multiprocessing import Queue import lm_eval @@ -26,41 +27,85 @@ import torch # pre-trained model path on Hugging Face. -MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" -# Math reasoning benchmark (Grade School Math 8K). -TASK = "gsm8k" +MODEL_NAME = ["Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct"] +# Benchmark configuration mapping models to evaluation tasks: +# - Text model: GSM8K (grade school math reasoning) +# - Vision-language model: MMMU Art & Design validation (multimodal understanding) +TASK = { + "Qwen/Qwen2.5-0.5B-Instruct": "gsm8k", + "Qwen/Qwen2.5-VL-3B-Instruct": "mmmu_val_art_and_design" +} # Answer validation requiring format consistency. -FILTER = "exact_match,strict-match" +FILTER = { + "Qwen/Qwen2.5-0.5B-Instruct": "exact_match,strict-match", + "Qwen/Qwen2.5-VL-3B-Instruct": "acc,none" +} # 3% relative tolerance for numerical accuracy. RTOL = 0.03 # Baseline accuracy after VLLM optimization. -EXPECTED_VALUE = 0.316 +EXPECTED_VALUE = { + "Qwen/Qwen2.5-0.5B-Instruct": 0.316, + "Qwen/Qwen2.5-VL-3B-Instruct": 0.541 +} +# Maximum context length configuration for each model. +MAX_MODEL_LEN = { + "Qwen/Qwen2.5-0.5B-Instruct": 4096, + "Qwen/Qwen2.5-VL-3B-Instruct": 8192 +} +# Model types distinguishing text-only and vision-language models. +MODEL_TYPE = { + "Qwen/Qwen2.5-0.5B-Instruct": "vllm", + "Qwen/Qwen2.5-VL-3B-Instruct": "vllm-vlm" +} +# wrap prompts in a chat-style template. +APPLY_CHAT_TEMPLATE = {"vllm": False, "vllm-vlm": True} +# Few-shot examples handling as multi-turn dialogues. +FEWSHOT_AS_MULTITURN = {"vllm": False, "vllm-vlm": True} -def run_test(queue, more_args=None): - model_args = f"pretrained={MODEL_NAME},max_model_len=4096" - if more_args is not None: - model_args = f"{model_args},{more_args}" - results = lm_eval.simple_evaluate( - model="vllm", - model_args=model_args, - tasks=TASK, - batch_size="auto", - ) - result = results["results"][TASK][FILTER] - print("result:", result) - queue.put(result) - del results - torch.npu.empty_cache() - gc.collect() +def run_test(queue, model, max_model_len, model_type): + try: + if model_type == "vllm-vlm": + model_args = (f"pretrained={model},max_model_len={max_model_len}," + "dtype=auto,max_images=2") + else: + model_args = (f"pretrained={model},max_model_len={max_model_len}," + "dtype=auto") + results = lm_eval.simple_evaluate( + model=model_type, + model_args=model_args, + tasks=TASK[model], + batch_size="auto", + apply_chat_template=APPLY_CHAT_TEMPLATE[model_type], + fewshot_as_multiturn=FEWSHOT_AS_MULTITURN[model_type], + ) + result = results["results"][TASK[model]][FILTER[model]] + print("result:", result) + queue.put(result) + except Exception as e: + queue.put(e) + sys.exit(1) + finally: + gc.collect() + torch.npu.empty_cache() -def test_lm_eval_accuracy(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context(): +@pytest.mark.parametrize("model", MODEL_NAME) +@pytest.mark.parametrize("VLLM_USE_V1", ["0", "1"]) +def test_lm_eval_accuracy(monkeypatch: pytest.MonkeyPatch, model, VLLM_USE_V1): + if model == "Qwen/Qwen2.5-VL-3B-Instruct" and VLLM_USE_V1 == "1": + pytest.skip( + "Qwen2.5-VL-3B-Instruct is not supported when VLLM_USE_V1=1") + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", VLLM_USE_V1) result_queue: Queue[float] = multiprocessing.Queue() - p = multiprocessing.Process(target=run_test, args=(result_queue, )) + p = multiprocessing.Process(target=run_test, + args=(result_queue, model, + MAX_MODEL_LEN[model], + MODEL_TYPE[model])) p.start() p.join() result = result_queue.get() - assert (EXPECTED_VALUE - RTOL < result < EXPECTED_VALUE + RTOL), \ - f"Expected: {EXPECTED_VALUE}±{RTOL} | Measured: {result}" + print(result) + assert (EXPECTED_VALUE[model] - RTOL < result < EXPECTED_VALUE[model] + RTOL), \ + f"Expected: {EXPECTED_VALUE[model]}±{RTOL} | Measured: {result}"