forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Core][VLM] Support image embeddings as input (vllm-project#6613)
- Loading branch information
1 parent
74494eb
commit e904c03
Showing
13 changed files
with
518 additions
and
139 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
from typing import List, Optional, Tuple, Type | ||
|
||
import pytest | ||
from transformers import AutoConfig, AutoTokenizer | ||
|
||
from vllm.sequence import SampleLogprobs | ||
|
||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets | ||
from .utils import check_logprobs_close | ||
|
||
pytestmark = pytest.mark.vlm | ||
|
||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ | ||
"stop_sign": | ||
"USER: <image>\nWhat's the content of the image?\nASSISTANT:", | ||
"cherry_blossom": | ||
"USER: <image>\nWhat is the season?\nASSISTANT:", | ||
}) | ||
|
||
models = [ | ||
"llava-hf/llava-1.5-7b-hf", | ||
] | ||
|
||
|
||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str, | ||
Optional[SampleLogprobs]], | ||
model: str): | ||
"""Sanitize vllm output to be comparable with hf output.""" | ||
output_ids, output_str, out_logprobs = vllm_output | ||
|
||
config = AutoConfig.from_pretrained(model) | ||
image_token_id = config.image_token_index | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(model) | ||
eos_token_id = tokenizer.eos_token_id | ||
|
||
hf_output_ids = [ | ||
token_id for idx, token_id in enumerate(output_ids) | ||
if token_id != image_token_id or output_ids[idx - 1] != image_token_id | ||
] | ||
|
||
assert output_str[0] == " " | ||
hf_output_str = output_str[1:] | ||
if hf_output_ids[-1] == eos_token_id: | ||
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) | ||
|
||
return hf_output_ids, hf_output_str, out_logprobs | ||
|
||
|
||
def run_test( | ||
hf_runner: Type[HfRunner], | ||
vllm_runner: Type[VllmRunner], | ||
image_assets: _ImageAssets, | ||
model: str, | ||
*, | ||
size_factors: List[float], | ||
dtype: str, | ||
max_tokens: int, | ||
num_logprobs: int, | ||
tensor_parallel_size: int, | ||
distributed_executor_backend: Optional[str] = None, | ||
): | ||
"""Inference result should be the same between hf and vllm. | ||
All the image fixtures for the test is under tests/images. | ||
For huggingface runner, we provide the PIL images as input. | ||
For vllm runner, we provide MultiModalDataDict objects | ||
and corresponding vision language config as input. | ||
Note, the text input is also adjusted to abide by vllm contract. | ||
The text output is sanitized to be able to compare with hf. | ||
""" | ||
|
||
# vLLM to load from image embeddings | ||
vllm_images = [asset.image_embeds for asset in image_assets] | ||
|
||
# transformers to load from PIL images | ||
hf_images = [asset.pil_image for asset in image_assets] | ||
|
||
vllm_inputs_per_image = [( | ||
[prompt for _ in size_factors], | ||
[image for _ in size_factors], | ||
) for image, prompt in zip(vllm_images, HF_IMAGE_PROMPTS)] | ||
|
||
hf_inputs_per_image = [( | ||
[prompt for _ in size_factors], | ||
[image for _ in size_factors], | ||
) for image, prompt in zip(hf_images, HF_IMAGE_PROMPTS)] | ||
|
||
# NOTE: take care of the order. run vLLM first, and then run HF. | ||
# vLLM needs a fresh new process without cuda initialization. | ||
# if we run HF first, the cuda initialization will be done and it | ||
# will hurt multiprocessing backend with fork method (the default method). | ||
|
||
# max_model_len should be greater than image_feature_size | ||
with vllm_runner(model, | ||
dtype=dtype, | ||
tensor_parallel_size=tensor_parallel_size, | ||
distributed_executor_backend=distributed_executor_backend, | ||
enforce_eager=True) as vllm_model: | ||
vllm_outputs_per_image = [ | ||
vllm_model.generate_greedy_logprobs(prompts, | ||
max_tokens, | ||
num_logprobs=num_logprobs, | ||
images=images) | ||
for prompts, images in vllm_inputs_per_image | ||
] | ||
|
||
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: | ||
hf_outputs_per_image = [ | ||
hf_model.generate_greedy_logprobs_limit(prompts, | ||
max_tokens, | ||
num_logprobs=num_logprobs, | ||
images=images) | ||
for prompts, images in hf_inputs_per_image | ||
] | ||
|
||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, | ||
vllm_outputs_per_image): | ||
# TODO: Check whether using original CLIPVisionModel can improve | ||
# consistency against HF | ||
check_logprobs_close( | ||
outputs_0_lst=hf_outputs, | ||
outputs_1_lst=[ | ||
vllm_to_hf_output(vllm_output, model) | ||
for vllm_output in vllm_outputs | ||
], | ||
name_0="hf", | ||
name_1="vllm", | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("model", models) | ||
@pytest.mark.parametrize( | ||
"size_factors", | ||
[ | ||
# No image | ||
[], | ||
# Single-scale | ||
[1.0], | ||
# Single-scale, batched | ||
[1.0, 1.0, 1.0], | ||
], | ||
) | ||
@pytest.mark.parametrize("dtype", ["half"]) | ||
@pytest.mark.parametrize("max_tokens", [128]) | ||
@pytest.mark.parametrize("num_logprobs", [5]) | ||
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, | ||
dtype: str, max_tokens: int, num_logprobs: int) -> None: | ||
run_test( | ||
hf_runner, | ||
vllm_runner, | ||
image_assets, | ||
model, | ||
size_factors=size_factors, | ||
dtype=dtype, | ||
max_tokens=max_tokens, | ||
num_logprobs=num_logprobs, | ||
tensor_parallel_size=1, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.