[Core][VLM] Support image embeddings as input #6613

Merged · 30 commits · Aug 12, 2024
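This PR lets several vLLM multimodal models accept precomputed image embeddings in place of raw pixel values. A rough usage sketch under stated assumptions: the `multi_modal_data` dict interface, the model name, the file path, and the `(num_image_tokens, hidden_size)` embedding shape are illustrative rather than taken from this diff.

```python
import torch

from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")

# Precomputed, already-projected embeddings for a single image,
# e.g. shape (576, 4096) for LLaVA-1.5-7B (assumption, not from this diff).
image_embeds = torch.load("image_embeds.pt")

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is shown in this image? ASSISTANT:",
    "multi_modal_data": {"image": image_embeds},
})
print(outputs[0].outputs[0].text)
```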
Changes from 11 commits
66 changes: 43 additions & 23 deletions vllm/model_executor/models/blip2.py
@@ -1,4 +1,4 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
import torch.nn as nn
@@ -28,6 +28,25 @@
"language_model.model": "language_model",
}

# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265


class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""


class Blip2ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor


Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]


class Blip2QFormerMultiHeadAttention(nn.Module):

@@ -375,20 +394,6 @@ def forward(
return sequence_output


class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""


Blip2ImageInputs = Blip2ImagePixelInputs

# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265


def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
return hf_config.num_query_tokens

@@ -506,18 +511,29 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Blip2ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)

if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None

if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

return Blip2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
return Blip2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)

if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Blip2ImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)

def _image_pixels_to_features(self, vision_model: BlipVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
@@ -538,6 +554,10 @@ def _process_image_pixels(self,

def _process_image_input(self,
image_input: Blip2ImageInputs) -> torch.Tensor:

if image_input["type"] == "image_embeds":
return image_input["data"]

assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)

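The blip2.py hunks above establish the pattern repeated in every file of this PR: the per-model image input becomes a `Union` of `Literal`-tagged `TypedDict`s, `_parse_and_validate_image_input` accepts either a `pixel_values` or an `image_embeds` kwarg, and `_process_image_input` returns the embeddings unchanged, bypassing the vision encoder (and any projector). A condensed, model-agnostic sketch of that dispatch; the encoder stub and shapes here are hypothetical.

```python
from typing import Literal, TypedDict, Union

import torch


class ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor  # (batch_size, num_channels, height, width)


class ImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor  # (batch_size, num_image_tokens, hidden_size)


ImageInputs = Union[ImagePixelInputs, ImageEmbeddingInputs]


def fake_vision_encoder(pixel_values: torch.Tensor) -> torch.Tensor:
    """Stand-in for a model-specific vision tower plus projector."""
    batch_size = pixel_values.shape[0]
    return torch.zeros(batch_size, 576, 4096)


def process_image_input(image_input: ImageInputs) -> torch.Tensor:
    # The "type" field discriminates the union: embeddings are returned as-is,
    # while pixel values still run through the encoder.
    if image_input["type"] == "image_embeds":
        return image_input["data"]
    return fake_vision_encoder(image_input["data"])
```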
23 changes: 20 additions & 3 deletions vllm/model_executor/models/fuyu.py
@@ -16,7 +16,7 @@
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
import torch.nn as nn
@@ -62,6 +62,14 @@ class FuyuImagePixelInputs(TypedDict):
"""


class FuyuImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor


FuyuImageInputs = Union[FuyuImagePixelInputs, FuyuImageEmbeddingInputs]


def _calculate_num_image_tokens(
height: int,
width: int,
@@ -249,6 +257,16 @@ def _parse_and_validate_image_input(self, **kwargs: object):
data=image_patches)
return None

def _process_image_input(self,
image_input: FuyuImageInputs) -> torch.Tensor:

if image_input["type"] == "image_embeds":
return image_input["data"]

assert self.vision_embed_tokens is not None
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
return vision_embeddings

def forward(
self,
input_ids: torch.Tensor,
@@ -261,8 +279,7 @@ def forward(
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
vision_embeddings, _ = self.vision_embed_tokens(
image_input["data"])
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
38 changes: 35 additions & 3 deletions vllm/model_executor/models/internvl.py
@@ -50,6 +50,15 @@ class InternVLImagePixelInputs(TypedDict):
"""


class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]]


InternVLImageInputs = Union[InternVLImagePixelInputs,
InternVLImageEmbeddingInputs]


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
@@ -378,13 +387,23 @@ def _validate_shape(d: torch.Tensor):
return data

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[InternVLImagePixelInputs]:
self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_token_id = kwargs.pop("image_token_id", None)
image_embeds = kwargs.pop("image_embeds", None)

if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None

if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return InternVLImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)

self.img_context_token_id = image_token_id[0]

if not isinstance(pixel_values, (torch.Tensor, list)):
@@ -396,6 +415,19 @@ def _parse_and_validate_image_input(
data=self._validate_pixel_values(pixel_values),
)

def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> torch.Tensor:

if image_input["type"] == "image_embeds":
return image_input["data"]

assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["data"])

return image_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -409,7 +441,7 @@ def forward(
if image_input is not None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vit_embeds = self.extract_feature(image_input["data"])
vit_embeds = self._process_image_input(image_input)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vit_embeds,
self.img_context_token_id)
55 changes: 37 additions & 18 deletions vllm/model_executor/models/llava.py
@@ -27,6 +27,20 @@
merge_vision_embeddings)


class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""


class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor


LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):

@@ -49,15 +63,6 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
return hidden_states


class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""


LlavaImageInputs = LlavaImagePixelInputs


def get_max_llava_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
@@ -210,18 +215,28 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)

if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None

if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)

if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return LlavaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)

def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
@@ -258,6 +273,10 @@ def _process_image_pixels(self,

def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor:

if image_input["type"] == "image_embeds":
return image_input["data"]

assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)
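Note that in the llava.py hunk above, the `image_embeds` branch returns its data before `multi_modal_projector` is applied, so the caller must supply embeddings that are already projected into the language model's hidden space. A hedged sketch of producing such embeddings offline with the HuggingFace LLaVA weights; the feature-layer index and CLS-token dropping mirror LLaVA-1.5 defaults, and the whole recipe is an assumption rather than anything specified by this PR.

```python
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open("example.jpg")
pixel_values = processor.image_processor(images=image,
                                         return_tensors="pt").pixel_values

with torch.no_grad():
    vision_out = model.vision_tower(pixel_values, output_hidden_states=True)
    # Select the configured feature layer and drop the CLS token
    # (LLaVA's "default" feature-select strategy).
    feats = vision_out.hidden_states[model.config.vision_feature_layer][:, 1:]
    image_embeds = model.multi_modal_projector(feats).squeeze(0)  # (576, 4096)

torch.save(image_embeds, "image_embeds.pt")
```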
52 changes: 37 additions & 15 deletions vllm/model_executor/models/llava_next.py
@@ -60,7 +60,13 @@ class LlavaNextImagePixelInputs(TypedDict):
"""


LlavaNextImageInputs = LlavaNextImagePixelInputs
class LlavaNextImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor


LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
LlavaNextImageEmbeddingInputs]


# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
@@ -208,7 +214,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
input_width=width,
)
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")

@@ -320,26 +326,38 @@ def _validate_shape(d: torch.Tensor):
return data

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)

if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None

if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")

return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes),
)
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes),
)

if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeds. "
f"Got type: {type(image_embeds)}")

return LlavaNextImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)

def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
@@ -466,6 +484,10 @@ def _process_image_input(
self,
image_input: LlavaNextImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:

if image_input["type"] == "image_embeds":
return [image_input["data"]]

patch_embeddings = self._process_image_pixels(image_input)

image_sizes = image_input.get("image_sizes")
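Of the models touched here, llava_next.py is the only one whose input processor also changes: when the multimodal input is already an embedding tensor, the number of image placeholder tokens expanded into the prompt is taken directly from the tensor's first dimension instead of being derived from the image geometry. A small illustrative sketch of that branch; the helper name is hypothetical and the pixel-value path is deliberately omitted.

```python
from typing import Union

import torch
from PIL import Image


def get_image_feature_size(image_data: Union[Image.Image, torch.Tensor]) -> int:
    """Hypothetical helper mirroring the branch added to input_processor_for_llava_next."""
    if isinstance(image_data, torch.Tensor):
        # Precomputed embeddings: one placeholder token per embedding row.
        return image_data.shape[0]
    if isinstance(image_data, Image.Image):
        # Pixel input: the real code derives this from the image size and the
        # model's patch-grid configuration, which is not reproduced here.
        raise NotImplementedError("pixel-value path omitted from this sketch")
    raise TypeError(f"Invalid image type: {type(image_data)}")
```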