[Core][VLM] Support image embeddings as input #6613

Merged · Aug 12, 2024 · 30 commits
23 changes: 20 additions & 3 deletions vllm/model_executor/models/fuyu.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 """ PyTorch Fuyu model."""
 import math
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
@@ -62,6 +62,14 @@ class FuyuImagePixelInputs(TypedDict):
     """
 
 
+class FuyuImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+
+
+FuyuImageInputs = Union[FuyuImagePixelInputs, FuyuImageEmbeddingInputs]
+
+
 def _calculate_num_image_tokens(
     height: int,
     width: int,
@@ -249,6 +257,16 @@ def _parse_and_validate_image_input(self, **kwargs: object):
                 data=image_patches)
         return None
 
+    def _process_image_input(self,
+                             image_input: FuyuImageInputs) -> torch.Tensor:
+
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
+        assert self.vision_embed_tokens is not None
+        vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
+        return vision_embeddings
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -261,8 +279,7 @@ def forward(
         image_input = self._parse_and_validate_image_input(**kwargs)
 
         if image_input is not None:
-            vision_embeddings, _ = self.vision_embed_tokens(
-                image_input["data"])
+            vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = self.language_model.model.embed_tokens(input_ids)
             inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
                                                     vision_embeddings,
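Every model touched by this PR repeats the same pattern: a second `TypedDict` tagged with `type: Literal["image_embeds"]`, a `Union` alias over both input kinds, and a dispatch on the `type` tag that returns embeddings untouched while routing pixels through the vision encoder. A minimal self-contained sketch of the idea (names and shapes below are illustrative, not vLLM's):

```python
from typing import Literal, TypedDict, Union

import torch


class PixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor


class EmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor


ImageInputs = Union[PixelInputs, EmbeddingInputs]


def encode_pixels(pixel_values: torch.Tensor) -> torch.Tensor:
    # Stand-in for a real vision tower: one embedding row set per image.
    return torch.zeros(pixel_values.shape[0], 4096)


def process_image_input(image_input: ImageInputs) -> torch.Tensor:
    # Checking the "type" tag narrows the union for type checkers, and
    # lets precomputed embeddings bypass the vision encoder entirely.
    if image_input["type"] == "image_embeds":
        return image_input["data"]
    return encode_pixels(image_input["data"])


embeds = torch.randn(1, 4096)
assert process_image_input({"type": "image_embeds", "data": embeds}) is embeds
```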
41 changes: 30 additions & 11 deletions vllm/model_executor/models/llava.py
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
@@ -59,7 +59,12 @@ class LlavaImagePixelInputs(TypedDict):
     """Shape: `(batch_size, num_channels, height, width)`"""
 
 
-LlavaImageInputs = LlavaImagePixelInputs
+class LlavaImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+
+
+LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
 
 
 def get_max_llava_image_tokens(ctx: InputContext):
@@ -174,18 +179,28 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
-        if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        return LlavaImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-        )
+        if pixel_values is not None:
+            if not isinstance(pixel_values, torch.Tensor):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+            return LlavaImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+            )
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return LlavaImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
 
     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
@@ -219,6 +234,10 @@ def _process_image_pixels(self,
 
     def _process_image_input(self,
                              image_input: LlavaImageInputs) -> torch.Tensor:
+
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
         assert self.vision_tower is not None
         image_features = self._process_image_pixels(image_input)
         return self.multi_modal_projector(image_features)
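Note that the embedding path above skips both the vision tower and the `multi_modal_projector`, so precomputed embeddings must already live in the language model's hidden space. Below is a hedged sketch of computing the pre-projector features offline with Hugging Face Transformers; the checkpoint and the penultimate-layer, drop-CLS selection follow LLaVA-1.5 conventions, and the shapes in the comments are assumptions:

```python
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

# Offline step: encode an image with the same vision tower LLaVA-1.5 uses.
NAME = "openai/clip-vit-large-patch14-336"
processor = CLIPImageProcessor.from_pretrained(NAME)
vision_tower = CLIPVisionModel.from_pretrained(NAME)

image = Image.open("example.jpg")  # any local image (assumed path)
pixel_values = processor(images=image, return_tensors="pt").pixel_values

with torch.no_grad():
    out = vision_tower(pixel_values, output_hidden_states=True)

# LLaVA's default feature selection: penultimate layer, CLS token dropped,
# giving roughly (1, 576, 1024) for a 336x336 input.
features = out.hidden_states[-2][:, 1:]

# These are *pre-projector* features. To serve them as "image_embeds" here,
# they must still be mapped by the model's multi_modal_projector into the
# language model's hidden size (4096 for LLaVA-1.5-7B).
```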
52 changes: 37 additions & 15 deletions vllm/model_executor/models/llava_next.py
@@ -59,7 +59,13 @@ class LlavaNextImagePixelInputs(TypedDict):
     """
 
 
-LlavaNextImageInputs = LlavaNextImagePixelInputs
+class LlavaNextImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+
+
+LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
+                             LlavaNextImageEmbeddingInputs]
 
 
 # Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
@@ -187,7 +193,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
             input_width=width,
         )
     elif isinstance(image_data, torch.Tensor):
-        raise NotImplementedError("Embeddings input is not supported yet")
+        return llm_inputs
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
@@ -285,26 +291,38 @@ def _validate_shape(d: torch.Tensor):
         return data
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
+            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
 
-        if not isinstance(image_sizes, torch.Tensor):
-            raise ValueError("Incorrect type of image sizes. "
-                             f"Got type: {type(image_sizes)}")
+            if not isinstance(image_sizes, torch.Tensor):
+                raise ValueError("Incorrect type of image sizes. "
+                                 f"Got type: {type(image_sizes)}")
 
-        return LlavaNextImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-            image_sizes=self._validate_image_sizes(image_sizes),
-        )
+            return LlavaNextImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+                image_sizes=self._validate_image_sizes(image_sizes),
+            )
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeds. "
+                                 f"Got type: {type(image_embeds)}")
+
+            return LlavaNextImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
 
     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
@@ -425,6 +443,10 @@ def _process_image_pixels(
 
     def _process_image_input(
             self, image_input: LlavaNextImageInputs) -> BatchedTensors:
+
+        if image_input["type"] == "image_embeds":
+            return [image_input["data"]]
+
         patch_embeddings = self._process_image_pixels(image_input)
 
         image_sizes = image_input.get("image_sizes")
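Here `_process_image_input` returns `BatchedTensors`, a list with one tensor per image, because LLaVA-NeXT's patch count varies between images; wrapping the embedding tensor in a list preserves that contract. Downstream, vLLM's `merge_vision_embeddings` scatters those rows into the placeholder image-token positions. A simplified sketch of that merge step (not the exact vLLM implementation):

```python
from typing import List, Union

import torch


def merge_vision_embeddings(
        input_ids: torch.Tensor,      # (num_tokens, )
        inputs_embeds: torch.Tensor,  # (num_tokens, hidden_size)
        vision_embeddings: Union[torch.Tensor, List[torch.Tensor]],
        image_token_id: int) -> torch.Tensor:
    """Scatter image embeddings into the placeholder token positions."""
    mask = input_ids == image_token_id

    if isinstance(vision_embeddings, list):
        # One tensor per image, flattened into a (total_rows, hidden) matrix.
        flat = torch.cat([v.reshape(-1, v.shape[-1]) for v in vision_embeddings])
    else:
        flat = vision_embeddings.reshape(-1, vision_embeddings.shape[-1])

    if int(mask.sum()) != flat.shape[0]:
        raise ValueError("Number of placeholder tokens does not match "
                         "the number of provided image embedding rows")

    inputs_embeds[mask] = flat.to(dtype=inputs_embeds.dtype)
    return inputs_embeds
```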
41 changes: 30 additions & 11 deletions vllm/model_executor/models/paligemma.py
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 from PIL import Image
@@ -148,7 +148,13 @@ class PaliGemmaImagePixelInputs(TypedDict):
     """Shape: (batch_size, num_channels, height, width)"""
 
 
-PaliGemmaImageInputs = PaliGemmaImagePixelInputs
+class PaliGemmaImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+
+
+PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
+                             PaliGemmaImageEmbeddingInputs]
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
@@ -198,18 +204,28 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
-        if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        return PaliGemmaImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-        )
+        if pixel_values is not None:
+            if not isinstance(pixel_values, torch.Tensor):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+            return PaliGemmaImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+            )
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return PaliGemmaImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
 
     def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
@@ -233,6 +249,9 @@ def _process_image_pixels(
     def _process_image_input(
             self, image_input: PaliGemmaImageInputs) -> torch.Tensor:
 
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
         assert self.vision_tower is not None
         image_features = self._process_image_pixels(image_input)
 
29 changes: 19 additions & 10 deletions vllm/model_executor/models/phi3v.py
@@ -503,22 +503,31 @@ def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
 
-        if not isinstance(image_sizes, torch.Tensor):
-            raise ValueError("Incorrect type of image sizes. "
-                             f"Got type: {type(image_sizes)}")
+            if not isinstance(image_sizes, torch.Tensor):
+                raise ValueError("Incorrect type of image sizes. "
+                                 f"Got type: {type(image_sizes)}")
 
-        return Phi3VImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-            image_sizes=self._validate_image_sizes(image_sizes))
+            return Phi3VImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+                image_sizes=self._validate_image_sizes(image_sizes))
+
+        # TODO: Enable image embeddings for Phi3-Vision
+        if image_embeds is not None:
+            raise NotImplementedError("Embeddings input is not supported yet")
 
     def forward(self,
                 input_ids: torch.Tensor,
6 changes: 5 additions & 1 deletion vllm/multimodal/image.py
@@ -113,6 +113,8 @@ def _get_hf_image_processor(self, model_config: ModelConfig):
     def _default_input_mapper(self, ctx: InputContext,
                               data: object) -> MultiModalInputs:
         model_config = ctx.model_config
+
+        # Raw image
         if isinstance(data, Image.Image):
             image_processor = self._get_hf_image_processor(model_config)
             if image_processor is None:
@@ -127,8 +129,10 @@ def _default_input_mapper(self, ctx: InputContext,
                 raise
 
             return MultiModalInputs(batch_data)
+
+        # Image embedding
         elif isinstance(data, torch.Tensor):
-            raise NotImplementedError("Embeddings input is not supported yet")
+            return MultiModalInputs({"image_embeds": data})
 
         raise TypeError(f"Invalid image type: {type(data)}")
 
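With this mapper change, passing a `torch.Tensor` as the image value reaches the model as the `image_embeds` keyword argument. A hedged end-to-end sketch of what the PR enables; the prompt template and the (576, 4096) embedding shape are assumptions for LLaVA-1.5-7B, and `image_embeds.pt` is a hypothetical file of precomputed, already-projected embeddings:

```python
import torch
from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")

# Hypothetical file of precomputed, already-projected embeddings;
# for LLaVA-1.5-7B that is 576 image tokens x 4096 hidden dims.
image_embeds = torch.load("image_embeds.pt")

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is in this image? ASSISTANT:",
    "multi_modal_data": {"image": image_embeds},
})
print(outputs[0].outputs[0].text)
```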