Skip to content

Commit

Permalink
Add multimodal tests for qwen
Browse files Browse the repository at this point in the history
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
  • Loading branch information
alex-jw-brooks committed Sep 1, 2024
1 parent 872b056 commit cb5802c
Showing 1 changed file with 164 additions and 13 deletions.
177 changes: 164 additions & 13 deletions tests/models/test_qwen.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,170 @@
from typing import Type
import pathlib
from typing import List, Optional, Type

import pytest
from transformers import AutoTokenizer

from ..conftest import HfRunner, VllmRunner
from vllm.model_executor.models.qwen import get_qwen_llm_inputs
from vllm.multimodal.utils import rescale_image_size

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

text_only_models = [
"Qwen/Qwen-7B-Chat" # Has no visual component
]

multimodal_models = ["Qwen/Qwen-VL"]

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"Picture 1: <img></img>\nWhat's the content of the image?: ",
"cherry_blossom":
"Picture 1: <img></img>\nWhat is the season?: ",
})


### Tests for multimodal Qwen models
@pytest.mark.parametrize("hf_input_text,vllm_input_text,num_images", [
("I have no image tags", "I have no image tags", 0),
("Picture 1: <img></img>\n", "Picture 1: <img></img>\n", 1),
("Picture 1: <img></img>\n", "<image>", 1),
("Picture 1: <img></img>\n Picture 2: <img></img>\n", "<image> <image>",
2),
])
def test_qwen_input_processor_tag_unification(hf_input_text, vllm_input_text,
num_images):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL",
trust_remote_code=True)
hf_tok_ids = tokenizer.encode(hf_input_text)
vllm_tok_ids = get_qwen_llm_inputs(
vllm_input_text,
tokenizer,
num_images,
multi_modal_data=None,
)["prompt_token_ids"]
assert len(vllm_tok_ids) == len(hf_tok_ids)
assert vllm_tok_ids == hf_tok_ids


def run_test(
tmp_path: pathlib.PosixPath,
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]

# Export the images to a tempdir and substitute it into the hf prompt;
# the contents between <img>/</img> will be ignored by VLLM, but the
# transformers implementation for the visual transformer parses this to
# reload it in the forward call; the contents are treated as a URL or a
# local path.
for idx, asset in enumerate(image_assets):
image_tmp_path = tmp_path / f"{asset.name}.jpg"
asset.pil_image.save(image_tmp_path)
HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace(
"<img></img>", f"<img>{image_tmp_path}</img>")

inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

# Text only tests; the primary purpose of this test is to ensure that we can
# load Qwen models, e.g., Qwen/Qwen-7B-Chat, that do not have a visual config,
# without any problems.
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
max_model_len=2048,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", multimodal_models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets,
model, size_factors, dtype, max_tokens,
num_logprobs) -> None:
run_test(
tmp_path,
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


### Tests for language only Qwen models
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
Expand All @@ -27,19 +179,18 @@ def test_text_only_qwen_model(
max_tokens: int,
num_logprobs: int,
):
# This test checks language inputs only, since the visual component
# for qwen-vl is still unsupported in VLLM. In the near-future, the
# implementation and this test will be extended to consider
# visual inputs as well.
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
# the primary purpose of this test is to ensure that we can
# load Qwen models, e.g., Qwen/Qwen-7B-Chat, that do not have a visual
# config, without any problems.
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts,
max_tokens,
num_logprobs=num_logprobs,
)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts,
max_tokens,
num_logprobs=num_logprobs,
Expand Down

0 comments on commit cb5802c

Please sign in to comment.