-
-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Model] Initialize Fuyu-8B support (#3924)
Co-authored-by: Roger Wang <ywang@roblox.com>
- Loading branch information
Showing
6 changed files
with
844 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import requests | ||
from PIL import Image | ||
|
||
from vllm import LLM, SamplingParams | ||
|
||
|
||
def run_fuyu(): | ||
llm = LLM(model="adept/fuyu-8b", max_model_len=4096) | ||
|
||
# single-image prompt | ||
prompt = "What is the highest life expectancy at of male?\n" | ||
url = "https://huggingface.co/adept/fuyu-8b/resolve/main/chart.png" | ||
image = Image.open(requests.get(url, stream=True).raw) | ||
sampling_params = SamplingParams(temperature=0, max_tokens=64) | ||
|
||
outputs = llm.generate( | ||
{ | ||
"prompt": prompt, | ||
"multi_modal_data": { | ||
"image": image | ||
}, | ||
}, | ||
sampling_params=sampling_params) | ||
|
||
for o in outputs: | ||
generated_text = o.outputs[0].text | ||
print(generated_text) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_fuyu() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
from typing import List, Optional, Tuple, Type | ||
|
||
import pytest | ||
|
||
from vllm.multimodal.utils import rescale_image_size | ||
from vllm.sequence import SampleLogprobs | ||
from vllm.utils import is_cpu | ||
|
||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets | ||
from .utils import check_logprobs_close | ||
|
||
pytestmark = pytest.mark.vlm | ||
|
||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ | ||
"stop_sign": "What's the content of the image?\n", # noqa: E501 | ||
"cherry_blossom": "What is the season?\n", | ||
"boardwalk": "What's in this image?\n", | ||
}) | ||
|
||
models = ["adept/fuyu-8b"] | ||
|
||
|
||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str, | ||
Optional[SampleLogprobs]]): | ||
"""Sanitize vllm output to be comparable with hf output.""" | ||
output_ids, output_str, out_logprobs = vllm_output | ||
|
||
hf_output_str = output_str.lstrip() + "|ENDOFTEXT|" | ||
|
||
return output_ids, hf_output_str, out_logprobs | ||
|
||
|
||
def run_test( | ||
hf_runner: Type[HfRunner], | ||
vllm_runner: Type[VllmRunner], | ||
image_assets: _ImageAssets, | ||
model: str, | ||
*, | ||
size_factors: List[float], | ||
dtype: str, | ||
max_tokens: int, | ||
num_logprobs: int, | ||
tensor_parallel_size: int, | ||
distributed_executor_backend: Optional[str] = None, | ||
): | ||
"""Inference result should be the same between hf and vllm. | ||
All the image fixtures for the test is under tests/images. | ||
For huggingface runner, we provide the PIL images as input. | ||
For vllm runner, we provide MultiModalDataDict objects | ||
and corresponding vision language config as input. | ||
Note, the text input is also adjusted to abide by vllm contract. | ||
The text output is sanitized to be able to compare with hf. | ||
""" | ||
images = [asset.pil_image for asset in image_assets] | ||
|
||
inputs_per_image = [( | ||
[prompt for _ in size_factors], | ||
[rescale_image_size(image, factor) for factor in size_factors], | ||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] | ||
|
||
# NOTE: take care of the order. run vLLM first, and then run HF. | ||
# vLLM needs a fresh new process without cuda initialization. | ||
# if we run HF first, the cuda initialization will be done and it | ||
# will hurt multiprocessing backend with fork method (the default method). | ||
|
||
# max_model_len should be greater than image_feature_size | ||
with vllm_runner(model, | ||
max_model_len=2560, | ||
max_num_seqs=1, | ||
dtype=dtype, | ||
tensor_parallel_size=tensor_parallel_size, | ||
distributed_executor_backend=distributed_executor_backend, | ||
enforce_eager=True) as vllm_model: | ||
vllm_outputs_per_image = [ | ||
vllm_model.generate_greedy_logprobs(prompts, | ||
max_tokens, | ||
num_logprobs=num_logprobs, | ||
images=vllm_images) | ||
for prompts, vllm_images in inputs_per_image | ||
] | ||
|
||
with hf_runner(model, dtype=dtype) as hf_model: | ||
hf_model.model.get_output_embeddings = lambda: \ | ||
hf_model.model.language_model.get_output_embeddings() | ||
eos_token_id = hf_model.processor.tokenizer.eos_token_id | ||
hf_outputs_per_image = [ | ||
hf_model.generate_greedy_logprobs_limit(prompts, | ||
max_tokens, | ||
num_logprobs=num_logprobs, | ||
images=hf_images, | ||
eos_token_id=eos_token_id) | ||
for prompts, hf_images in inputs_per_image | ||
] | ||
|
||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, | ||
vllm_outputs_per_image): | ||
check_logprobs_close( | ||
outputs_0_lst=hf_outputs, | ||
outputs_1_lst=[ | ||
vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs | ||
], | ||
name_0="hf", | ||
name_1="vllm", | ||
) | ||
|
||
|
||
target_dtype = "half" | ||
if is_cpu(): | ||
target_dtype = "bfloat16" | ||
|
||
|
||
@pytest.mark.parametrize("model", models) | ||
@pytest.mark.parametrize( | ||
"size_factors", | ||
[ | ||
# No image | ||
[], | ||
# Single-scale | ||
[0.25], | ||
# Single-scale, batched | ||
[0.25, 0.25, 0.25], | ||
# Multi-scale | ||
[0.25, 0.2, 0.15], | ||
], | ||
) | ||
@pytest.mark.parametrize("dtype", [target_dtype]) | ||
@pytest.mark.parametrize("max_tokens", [128]) | ||
@pytest.mark.parametrize("num_logprobs", [10]) | ||
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, | ||
dtype: str, max_tokens: int, num_logprobs: int) -> None: | ||
run_test( | ||
hf_runner, | ||
vllm_runner, | ||
image_assets, | ||
model, | ||
size_factors=size_factors, | ||
dtype=dtype, | ||
max_tokens=max_tokens, | ||
num_logprobs=num_logprobs, | ||
tensor_parallel_size=1, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.