Commit d77fa4f

add qwen2.5-vl
Signed-off-by: wangli <wangli858794774@gmail.com>
Parent: 5fa70b6

File tree: 5 files changed (+83 −5 lines)

  tests/conftest.py
  tests/model_utils.py
  tests/multicard/test_offline_inference_distributed.py
  tests/ops/test_rotary_embedding.py
  tests/singlecard/test_offline_inference.py

tests/conftest.py

Lines changed: 9 additions & 1 deletion
@@ -21,6 +21,7 @@
 from typing import List, Optional, Tuple, TypeVar, Union
 
 import numpy as np
+import contextlib
 import pytest
 import torch
 from PIL import Image
@@ -34,7 +35,7 @@
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import is_list_of
 
-from tests.model_utils import (TokensTextLogprobs,
+from tests.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
 
 logger = init_logger(__name__)
@@ -51,6 +52,8 @@
 def cleanup_dist_env_and_memory():
     destroy_model_parallel()
     destroy_distributed_environment()
+    with contextlib.suppress(AssertionError):
+        torch.distributed.destroy_process_group()
     gc.collect()
     torch.npu.empty_cache()
 
@@ -340,3 +343,8 @@ def __exit__(self, exc_type, exc_value, traceback):
 @pytest.fixture(scope="session")
 def vllm_runner():
     return VllmRunner
+
+
+@pytest.fixture(params=list(PROMPT_TEMPLATES.keys()))
+def prompt_template(request):
+    return PROMPT_TEMPLATES[request.param]
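
The new prompt_template fixture is parametrized over the keys of PROMPT_TEMPLATES, so any test that accepts it runs once per registered template and receives the corresponding prompt-building callable. A minimal sketch of how a test consumes it (the test body below is illustrative only, not part of this change):

# Illustrative only: pytest calls this once per PROMPT_TEMPLATES entry
# (currently just "qwen2.5vl"), passing the template callable in.
def test_prompt_template_renders(prompt_template):
    prompts = prompt_template(["What is the content of this image?"])
    assert len(prompts) == 1
    assert "<|vision_start|>" in prompts[0]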

tests/model_utils.py

Lines changed: 14 additions & 1 deletion
@@ -18,7 +18,7 @@
 #
 
 import warnings
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 from vllm.config import ModelConfig, TaskOption
@@ -301,3 +301,16 @@ def build_model_context(model_name: str,
         limit_mm_per_prompt=limit_mm_per_prompt,
     )
     return InputContext(model_config)
+
+
+def qwen_prompt(questions: List[str]) -> List[str]:
+    placeholder = "<|image_pad|>"
+    return [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+             f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
+             f"{q}<|im_end|>\n<|im_start|>assistant\n") for q in questions]
+
+
+# Map of prompt templates for different models.
+PROMPT_TEMPLATES: dict[str, Callable] = {
+    "qwen2.5vl": qwen_prompt,
+}
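
Written out for a single question, the template produced by qwen_prompt is the standard Qwen2.5-VL chat layout; this is simply the f-string above with its newlines expanded, shown for illustration:

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|>What is the content of this image?<|im_end|>
<|im_start|>assistant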

tests/multicard/test_offline_inference_distributed.py

Lines changed: 29 additions & 0 deletions
@@ -25,6 +25,7 @@
 import pytest
 import vllm  # noqa: F401
 from conftest import VllmRunner
+from vllm.assets.image import ImageAsset
 
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 
@@ -50,6 +51,34 @@ def test_models_distributed(model: str,
         vllm_model.generate_greedy(example_prompts, max_tokens)
 
 
+@pytest.mark.parametrize("model", ["Qwen/Qwen2.5-VL-32B-Instruct"])
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
+                    reason="qwen2.5_vl is not supported on v1")
+def test_multimodal(model: str, prompt_template, vllm_runner):
+    image = ImageAsset("cherry_blossom") \
+        .pil_image.convert("RGB")
+    img_questions = [
+        "What is the content of this image?",
+        "Describe the content of this image in detail.",
+        "What's in the image?",
+        "Where is this image taken?",
+    ]
+    images = [image] * len(img_questions)
+    prompts = prompt_template(img_questions)
+    with vllm_runner(model,
+                     max_model_len=4096,
+                     tensor_parallel_size=4,
+                     distributed_executor_backend="mp",
+                     mm_processor_kwargs={
+                         "min_pixels": 28 * 28,
+                         "max_pixels": 1280 * 28 * 28,
+                         "fps": 1,
+                     }) as vllm_model:
+        vllm_model.generate_greedy(prompts=prompts,
+                                   images=images,
+                                   max_tokens=64)
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])
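
The mm_processor_kwargs bounds are pixel budgets for image resizing: min_pixels = 28 * 28 = 784 and max_pixels = 1280 * 28 * 28 = 1,003,520. Assuming the Qwen2.5-VL convention that each visual token covers a 28x28 pixel patch (an assumption, not stated in this diff), this caps an image at roughly 1280 visual tokens:

# Quick arithmetic check of the bounds used above (illustrative only).
min_pixels = 28 * 28           # 784 pixels -> at least ~1 visual token
max_pixels = 1280 * 28 * 28    # 1,003,520 pixels -> at most ~1280 visual tokens
assert max_pixels // (28 * 28) == 1280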

tests/ops/test_rotary_embedding.py

Lines changed: 1 addition & 0 deletions
@@ -202,3 +202,4 @@ def test_rotary_embedding_quant_with_leading_dim(
                                ref_key,
                                atol=DEFAULT_ATOL,
                                rtol=DEFAULT_RTOL)
+    torch.npu.empty_cache()

tests/singlecard/test_offline_inference.py

Lines changed: 30 additions & 3 deletions
@@ -25,12 +25,13 @@
 import pytest
 import vllm  # noqa: F401
 from conftest import VllmRunner
+from vllm.assets.image import ImageAsset
 
 import vllm_ascend  # noqa: F401
 
-MODELS = [
-    "Qwen/Qwen2.5-0.5B-Instruct",
-]
+MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
+MULTIMODALITY_MODELS = ["Qwen/Qwen2.5-VL-3B-Instruct"]
+
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 
 
@@ -53,6 +54,32 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
         vllm_model.generate_greedy(example_prompts, max_tokens)
 
 
+@pytest.mark.parametrize("model", MULTIMODALITY_MODELS)
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
+                    reason="qwen2.5_vl is not supported on v1")
+def test_multimodal(model: str, prompt_template, vllm_runner):
+    image = ImageAsset("cherry_blossom") \
+        .pil_image.convert("RGB")
+    img_questions = [
+        "What is the content of this image?",
+        "Describe the content of this image in detail.",
+        "What's in the image?",
+        "Where is this image taken?",
+    ]
+    images = [image] * len(img_questions)
+    prompts = prompt_template(img_questions)
+    with vllm_runner(model,
+                     max_model_len=4096,
+                     mm_processor_kwargs={
+                         "min_pixels": 28 * 28,
+                         "max_pixels": 1280 * 28 * 28,
+                         "fps": 1,
+                     }) as vllm_model:
+        vllm_model.generate_greedy(prompts=prompts,
+                                   images=images,
+                                   max_tokens=64)
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])
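
Since PROMPT_TEMPLATES drives the prompt_template fixture, covering another multimodal model later would presumably mean registering one more template function and model entry. A hypothetical sketch (the template format and model id below are placeholders, not part of this commit):

# Hypothetical extension in tests/model_utils.py and the test files; placeholder names.
def other_vl_prompt(questions: List[str]) -> List[str]:
    return [f"<image>\n{q}" for q in questions]  # placeholder prompt format

PROMPT_TEMPLATES["other-vl"] = other_vl_prompt
MULTIMODALITY_MODELS.append("org/Other-VL-Instruct")  # placeholder model id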
