From cc7835a63edf8258baf7875e62cc3362cd01c54d Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Sun, 30 Jun 2024 11:44:25 +0800
Subject: [PATCH] [CI/Build] Reuse code for checking output consistency (#5988)

---
 .../test_basic_correctness.py | 15 +++----
 .../basic_correctness/test_chunked_prefill.py | 15 +++----
 tests/basic_correctness/test_preemption.py | 16 ++++----
 .../test_basic_distributed_correctness.py | 15 +++----
 .../test_chunked_prefill_distributed.py | 15 +++----
 tests/models/test_big_models.py | 15 +++----
 tests/models/test_llava.py | 18 +++++----
 tests/models/test_llava_next.py | 18 +++++----
 tests/models/test_models.py | 15 +++----
 tests/models/test_phi3v.py | 18 +++++----
 tests/models/utils.py | 40 ++++++++++++++++++-
 11 files changed, 125 insertions(+), 75 deletions(-)

diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 6f44030feebb0..a7b0fef533ccb 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -8,6 +8,8 @@
 from vllm import LLM
 
+from ..models.utils import check_outputs_equal
+
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
 ]
@@ -46,10 +48,9 @@ def test_models(
                      gpu_memory_utilization=0.7) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py
index 48d6091282b89..767e0628765bd 100644
--- a/tests/basic_correctness/test_chunked_prefill.py
+++ b/tests/basic_correctness/test_chunked_prefill.py
@@ -8,6 +8,8 @@
 """
 import pytest
 
+from ..models.utils import check_outputs_equal
+
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
@@ -54,10 +56,9 @@ def test_models(
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py
index 7f20b2d934942..d60cc95d75433 100644
--- a/tests/basic_correctness/test_preemption.py
+++ b/tests/basic_correctness/test_preemption.py
@@ -12,6 +12,8 @@
 from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                  ENABLE_ARTIFICIAL_PREEMPT)
 
+from ..models.utils import check_outputs_equal
+
 MODELS = [
     "facebook/opt-125m",
 ]
@@ -94,13 +96,13 @@ def test_preemption(
         total_preemption = (
             vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
+
     assert ("is preempted by PreemptionMode.RECOMPUTE mode because there "
             "is not enough KV cache space." in caplog_vllm.text)
     # Ensure the count bucket of request-level histogram metrics matches
diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py
index b8ae5b4c44f8d..1f5fff3e1930b 100644
--- a/tests/distributed/test_basic_distributed_correctness.py
+++ b/tests/distributed/test_basic_distributed_correctness.py
@@ -17,6 +17,8 @@
 import pytest
 import torch
 
+from ..models.utils import check_outputs_equal
+
 MODELS = [
     os.environ["TEST_DIST_MODEL"],
 ]
@@ -48,10 +50,9 @@ def test_models(
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py
index 4e4e468c4377a..fd89147ac14b0 100644
--- a/tests/distributed/test_chunked_prefill_distributed.py
+++ b/tests/distributed/test_chunked_prefill_distributed.py
@@ -16,6 +16,8 @@
 import pytest
 import torch
 
+from ..models.utils import check_outputs_equal
+
 MODELS = [
     os.environ["TEST_DIST_MODEL"],
 ]
@@ -59,10 +61,9 @@ def test_models(
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py
index ef78283731775..c3e48b56ee58f 100644
--- a/tests/models/test_big_models.py
+++ b/tests/models/test_big_models.py
@@ -7,6 +7,8 @@
 import pytest
 import torch
 
+from .utils import check_outputs_equal
+
 MODELS = [
     "meta-llama/Llama-2-7b-hf",
     # "mistralai/Mistral-7B-v0.1",  # Tested by test_mistral.py
@@ -40,13 +42,12 @@ def test_models(
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
 @pytest.mark.parametrize("model", MODELS)
diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py
index f2dfd4bb8596f..c60b15afcd04d 100644
--- a/tests/models/test_llava.py
+++ b/tests/models/test_llava.py
@@ -6,6 +6,7 @@
 from vllm.config import VisionLanguageConfig
 
 from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from .utils import check_outputs_equal
 
 pytestmark = pytest.mark.vlm
 
@@ -109,14 +110,15 @@ def run_test(
                                                   max_tokens,
                                                   images=vllm_images)
 
-    for i in range(len(HF_IMAGE_PROMPTS)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
-            vllm_outputs[i], vlm_config, model_id)
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        hf_outputs,
+        [
+            vllm_to_hf_output(vllm_output, vlm_config, model_id)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
 @pytest.mark.parametrize("model_and_config", model_and_vl_config)
diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py
index d36e503871ca9..940d5035ef3f2 100644
--- a/tests/models/test_llava_next.py
+++ b/tests/models/test_llava_next.py
@@ -6,6 +6,7 @@
 from vllm.config import VisionLanguageConfig
 
 from ..conftest import IMAGE_ASSETS
+from .utils import check_outputs_equal
 
 pytestmark = pytest.mark.vlm
 
@@ -115,11 +116,12 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
                                                   max_tokens,
                                                   images=vllm_images)
 
-    for i in range(len(HF_IMAGE_PROMPTS)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
-            vllm_outputs[i], vlm_config, model_id)
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        hf_outputs,
+        [
+            vllm_to_hf_output(vllm_output, vlm_config, model_id)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index 4453b4b9f0523..4cd2cb665c8f0 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -7,6 +7,8 @@
 """
 import pytest
 
+from .utils import check_outputs_equal
+
 MODELS = [
     "facebook/opt-125m",
     "gpt2",
@@ -41,13 +43,12 @@ def test_models(
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
 @pytest.mark.parametrize("model", MODELS)
diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py
index e7d5639494104..2e34fa8c14010 100644
--- a/tests/models/test_phi3v.py
+++ b/tests/models/test_phi3v.py
@@ -7,6 +7,7 @@
 from vllm.utils import is_cpu
 
 from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from .utils import check_outputs_equal
 
 pytestmark = pytest.mark.vlm
 
@@ -124,14 +125,15 @@ def run_test(
                                                   max_tokens,
                                                   images=vllm_images)
 
-    for i in range(len(HF_IMAGE_PROMPTS)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
-            vllm_outputs[i], vlm_config, model_id)
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        hf_outputs,
+        [
+            vllm_to_hf_output(vllm_output, vlm_config, model_id)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
 # Since we use _attn_implementation="eager" for hf_runner, here is
diff --git a/tests/models/utils.py b/tests/models/utils.py
index 3e49dfb331176..0d5e304d84463 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -1,7 +1,43 @@
-def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1):
-    """Compare the logprobs of two sequences generated by different models,
+from typing import Dict, List, Tuple
+
+TokensText = Tuple[List[int], str]
+
+
+def check_outputs_equal(outputs_0_lst: List[TokensText],
+                        outputs_1_lst: List[TokensText], name_0: str,
+                        name_1: str):
+    """
+    Compare the two sequences generated by different models,
+    which should be equal.
+    """
+    assert len(outputs_0_lst) == len(outputs_1_lst)
+
+    for prompt_idx, (outputs_0,
+                     outputs_1) in enumerate(zip(outputs_0_lst,
+                                                 outputs_1_lst)):
+        output_ids_0, output_str_0 = outputs_0
+        output_ids_1, output_str_1 = outputs_1
+
+        assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
+                                              f"\n{name_0}:\t{output_str_0!r}"
+                                              f"\n{name_1}:\t{output_str_1!r}")
+        assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
+                                              f"\n{name_0}:\t{output_str_0!r}"
+                                              f"\n{name_1}:\t{output_str_1!r}")
+
+
+TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
+
+
+def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
+                         outputs_1_lst: List[TokensTextLogprobs], name_0: str,
+                         name_1: str):
+    """
+    Compare the logprobs of two sequences generated by different models,
     which should be similar but not necessarily equal.
     """
+    assert len(outputs_0_lst) == len(outputs_1_lst)
+
     # Loop through responses to each prompt.
     for prompt_idx, (outputs_0,
                      outputs_1) in enumerate(zip(outputs_0_lst,