[Core][VLM] Support image embeddings as input (vllm-project#6613)
ywang96 authored and sfc-gh-mkeralapura committed Aug 12, 2024
1 parent 74494eb commit e904c03
Showing 13 changed files with 518 additions and 139 deletions.
11 changes: 11 additions & 0 deletions docs/source/models/vlm.rst
@@ -49,6 +49,17 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptI
"multi_modal_data": {"image": image},
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
# Inference with image embeddings as input
image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {"image": image_embeds},
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
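
For reference, one way to produce such an embedding tensor offline is to run the vision tower and multi-modal projector of the HuggingFace LLaVA-1.5 checkpoint yourself. The sketch below is illustrative and not part of this commit: it assumes the transformers LlavaForConditionalGeneration implementation exposes vision_tower, multi_modal_projector, vision_feature_layer, and vision_feature_select_strategy (true for recent releases), and "example.jpg" is a placeholder path.

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)

image = Image.open("example.jpg")  # placeholder path
pixel_values = processor.image_processor(image, return_tensors="pt")["pixel_values"]
pixel_values = pixel_values.to(model.dtype)

with torch.no_grad():
    # Hidden states of the CLIP vision tower at the layer LLaVA reads (-2 by default).
    vision_out = model.vision_tower(pixel_values, output_hidden_states=True)
    feats = vision_out.hidden_states[model.config.vision_feature_layer]
    if model.config.vision_feature_select_strategy == "default":
        feats = feats[:, 1:]  # drop the CLS token, keeping the 576 patch features
    # Project into the language model's embedding space.
    image_embeds = model.multi_modal_projector(feats)  # (1, 576, 4096) for the 7B model

torch.save(image_embeds, "image_embeds.pt")

The saved tensor can then be passed to llm.generate exactly as in the snippet above.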
159 changes: 159 additions & 0 deletions tests/models/test_llava_image_embeds.py
@@ -0,0 +1,159 @@
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoConfig, AutoTokenizer

from vllm.sequence import SampleLogprobs

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"USER: <image>\nWhat's the content of the image?\nASSISTANT:",
"cherry_blossom":
"USER: <image>\nWhat is the season?\nASSISTANT:",
})

models = [
"llava-hf/llava-1.5-7b-hf",
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output

config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index

tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id

hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]

assert output_str[0] == " "
hf_output_str = output_str[1:]
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

return hf_output_ids, hf_output_str, out_logprobs


def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test are under tests/images.
For the huggingface runner, we provide the PIL images as input.
For the vllm runner, we provide MultiModalDataDict objects
and the corresponding vision language config as input.
Note that the text input is also adjusted to abide by the vllm contract.
The text output is sanitized so that it can be compared with hf.
"""

# vLLM to load from image embeddings
vllm_images = [asset.image_embeds for asset in image_assets]

# transformers to load from PIL images
hf_images = [asset.pil_image for asset in image_assets]

vllm_inputs_per_image = [(
[prompt for _ in size_factors],
[image for _ in size_factors],
) for image, prompt in zip(vllm_images, HF_IMAGE_PROMPTS)]

hf_inputs_per_image = [(
[prompt for _ in size_factors],
[image for _ in size_factors],
) for image, prompt in zip(hf_images, HF_IMAGE_PROMPTS)]

# NOTE: take care of the order: run vLLM first, and then run HF.
# vLLM needs a fresh process without CUDA initialization.
# If we run HF first, CUDA will already be initialized, which
# interferes with the multiprocessing backend when using the fork start method (the default).

# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in vllm_inputs_per_image
]

with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in hf_inputs_per_image
]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
10 changes: 10 additions & 0 deletions vllm/assets/image.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Literal

import torch
from PIL import Image

from vllm.assets.base import get_vllm_public_assets
@@ -18,3 +19,12 @@ def pil_image(self) -> Image.Image:
image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
s3_prefix=VLM_IMAGES_DIR)
return Image.open(image_path)

@property
def image_embeds(self) -> torch.Tensor:
"""
Image embeddings, only used for testing purposes with llava 1.5.
"""
image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
s3_prefix=VLM_IMAGES_DIR)
return torch.load(image_path)
72 changes: 49 additions & 23 deletions vllm/model_executor/models/blip2.py
@@ -1,4 +1,4 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
import torch.nn as nn
@@ -28,6 +28,29 @@
"language_model.model": "language_model",
}

# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265


class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""


class Blip2ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""


Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]


class Blip2QFormerMultiHeadAttention(nn.Module):

@@ -375,20 +398,6 @@ def forward(
return sequence_output


class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""


Blip2ImageInputs = Blip2ImagePixelInputs

# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265


def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
return hf_config.num_query_tokens

@@ -506,18 +515,31 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Blip2ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)

if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None

if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

return Blip2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
return Blip2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)

if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Blip2ImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)

raise AssertionError("This line should be unreachable.")

def _image_pixels_to_features(self, vision_model: BlipVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
@@ -538,6 +560,10 @@ def _process_image_pixels(self,

def _process_image_input(self,
image_input: Blip2ImageInputs) -> torch.Tensor:

if image_input["type"] == "image_embeds":
return image_input["data"]

assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)

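With the Blip2ImageEmbeddingInputs branch above, the same embedding-input path documented for LLaVA also applies to BLIP-2. Below is a hedged usage sketch, not taken from this commit: the tensor file is a placeholder, and the shape assumes BLIP-2's 32 query tokens with embeddings already projected to the hidden size of the OPT-2.7B backbone (2560).

import torch
from vllm import LLM

llm = LLM(model="Salesforce/blip2-opt-2.7b")

# Pre-computed embeddings, e.g. shape (1, 32, 2560): one row per Q-Former query
# token, already in the language model's embedding space.
image_embeds = torch.load("blip2_image_embeds.pt")  # placeholder path

outputs = llm.generate({
    # No <image> token in the prompt: BLIP2_IMAGE_TOKEN placeholders are inserted internally.
    "prompt": "Question: What is shown in the image? Answer:",
    "multi_modal_data": {"image": image_embeds},
})
print(outputs[0].outputs[0].text)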
8 changes: 7 additions & 1 deletion vllm/model_executor/models/clip.py
@@ -88,7 +88,13 @@ def input_processor_for_clip(
tokenizer = cached_get_tokenizer(model_config.tokenizer)

if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_clip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
else:
image_feature_size = image_feature_size_override

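The practical effect of this branch is that, when pre-computed embeddings are supplied, the number of <image> placeholder tokens is no longer taken from the CLIP config but from the tensor itself. A distilled, standalone sketch of that rule (infer_feature_size is a hypothetical helper, not part of vLLM):

import torch
from PIL import Image


def infer_feature_size(image_data, config_feature_size: int) -> int:
    """Mirror of the branch in input_processor_for_clip (hypothetical helper)."""
    if isinstance(image_data, Image.Image):
        # PIL image: the feature size comes from the CLIP patch grid via the config.
        return config_feature_size
    if isinstance(image_data, torch.Tensor):
        # Pre-computed embeddings: one placeholder token per row of the tensor.
        return image_data.shape[0]
    raise TypeError(f"Invalid image type: {type(image_data)}")


# Whatever the tensor's first dimension is becomes the placeholder count,
# e.g. 576 rows (LLaVA-1.5's 24x24 patch grid) yield 576 placeholders.
assert infer_feature_size(torch.zeros(576, 4096), config_feature_size=576) == 576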
13 changes: 10 additions & 3 deletions vllm/model_executor/models/fuyu.py
@@ -234,7 +234,8 @@ def __init__(self,
cache_config=cache_config,
quant_config=quant_config)

def _parse_and_validate_image_input(self, **kwargs: object):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
image_patches = kwargs.pop("image_patches", None)

if isinstance(image_patches, torch.Tensor):
@@ -249,6 +250,13 @@ def _parse_and_validate_image_input(self, **kwargs: object):
data=image_patches)
return None

def _process_image_input(
self, image_input: FuyuImagePixelInputs) -> torch.Tensor:

assert self.vision_embed_tokens is not None
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
return vision_embeddings

def forward(
self,
input_ids: torch.Tensor,
@@ -261,8 +269,7 @@ def forward(
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
vision_embeddings, _ = self.vision_embed_tokens(
image_input["data"])
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,