From 22fafe37847014d5af2f1ad9a43d7f5d090cc1a8 Mon Sep 17 00:00:00 2001
From: chaunceyjiang
Date: Sat, 8 Mar 2025 09:58:53 +0000
Subject: [PATCH] [Frontend] support image embeds

Signed-off-by: chaunceyjiang
---
 docs/source/serving/multimodal_inputs.md |  67 +++++++++++++-
 vllm/entrypoints/chat_utils.py           | 113 +++++++++++++++++++++--
 vllm/multimodal/image.py                 |  19 ++++
 vllm/multimodal/utils.py                 |  14 ++-
 4 files changed, 201 insertions(+), 12 deletions(-)

diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/serving/multimodal_inputs.md
index c540bff2cf30..2e2016c95e4f 100644
--- a/docs/source/serving/multimodal_inputs.md
+++ b/docs/source/serving/multimodal_inputs.md
@@ -462,4 +462,69 @@ export VLLM_AUDIO_FETCH_TIMEOUT=
 
 ### Embedding Inputs
 
-TBD
+To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
+pass a tensor of the appropriate shape to the corresponding field of the multi-modal dictionary.
+
+#### Image Embedding Inputs
+
+For image embeddings, you can pass the base64-encoded tensor to the `image_embeds` field.
+The following example demonstrates how to pass image embeddings to the OpenAI server:
+
+```python
+import base64
+import io
+
+import torch
+from openai import OpenAI
+
+
+def encode_tensor_base64(t: torch.Tensor) -> str:
+    """Serialize a tensor with torch.save and encode it as a base64 string."""
+    buffer = io.BytesIO()
+    torch.save(t, buffer)
+    buffer.seek(0)
+    return base64.b64encode(buffer.read()).decode('utf-8')
+
+
+image_embedding = torch.load(...)
+grid_thw = torch.load(...)  # Required by Qwen/Qwen2-VL-2B-Instruct
+image_sizes = torch.load(...)  # Required by openbmb/MiniCPM-V-2_6
+
+base64_image_embedding = encode_tensor_base64(image_embedding)
+base64_image_grid_thw = encode_tensor_base64(grid_thw)
+base64_image_sizes = encode_tensor_base64(image_sizes)
+
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+    # defaults to os.environ.get("OPENAI_API_KEY")
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+# Basic usage - this is equivalent to the LLaVA example for offline inference
+model = "llava-hf/llava-1.5-7b-hf"
+embeds = {
+    "type": "image_embeds",
+    "image_embeds": f"{base64_image_embedding}",
+}
+
+# Pass additional parameters (available to Qwen2-VL and MiniCPM-V)
+model = "Qwen/Qwen2-VL-2B-Instruct"
+embeds = {
+    "type": "image_embeds",
+    "image_embeds": {
+        "image_embeds": f"{base64_image_embedding}",  # Required
+        "image_grid_thw": f"{base64_image_grid_thw}",  # Required by Qwen/Qwen2-VL-2B-Instruct
+    },
+}
+
+model = "openbmb/MiniCPM-V-2_6"
+embeds = {
+    "type": "image_embeds",
+    "image_embeds": {
+        "image_embeds": f"{base64_image_embedding}",  # Required
+        "image_sizes": f"{base64_image_sizes}",  # Required by openbmb/MiniCPM-V-2_6
+    },
+}
+
+chat_completion = client.chat.completions.create(
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": [
+            {
+                "type": "text",
+                "text": "What's in this image?",
+            },
+            embeds,
+        ]},
+    ],
+    model=model,
+)
+```
+
+:::{note}
+Only one message can contain `{"type": "image_embeds"}`.
+If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
+:::
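For comparison, the "equivalent ... offline inference" path referenced in the example's comment passes the decoded tensor straight into the multi-modal dictionary. A minimal sketch follows (not part of this patch; the prompt format and tensor contents are illustrative placeholders mirroring the offline LLaVA example in the docs):

```python
# Hedged sketch: offline-inference counterpart of the online example above.
# Not added by this patch; prompt and tensor contents are illustrative only.
import torch

from vllm import LLM

image_embeds = torch.load(...)  # Pre-computed image features for one image

llm = LLM(model="llava-hf/llava-1.5-7b-hf")
outputs = llm.generate({
    "prompt": "USER: <image>\nWhat's in this image?\nASSISTANT:",
    "multi_modal_data": {"image": image_embeds},
})
```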
+ """ + type: Required[Literal["image_embeds"]] + """The type of the content part.""" + + class VideoURL(TypedDict, total=False): url: Required[str] """ @@ -109,6 +120,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): ChatCompletionContentPartInputAudioParam, ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam, CustomChatCompletionContentSimpleImageParam, + ChatCompletionContentPartImageEmbedsParam, CustomChatCompletionContentSimpleAudioParam, CustomChatCompletionContentSimpleVideoParam, str] @@ -350,7 +362,7 @@ def resolve_chat_template_content_format( return detected_format -ModalityStr = Literal["image", "audio", "video"] +ModalityStr = Literal["image", "audio", "video", "image_embeds"] _T = TypeVar("_T") @@ -391,7 +403,7 @@ def _placeholder_str(self, modality: ModalityStr, hf_config = self._model_config.hf_config model_type = hf_config.model_type - if modality == "image": + if modality in ["image", "image_embeds"]: if model_type == "phi3_v": # Workaround since this token is not defined in the tokenizer return f"<|image_{current_count}|>" @@ -470,10 +482,27 @@ def create_parser(self) -> "BaseMultiModalContentParser": class MultiModalItemTracker(BaseMultiModalItemTracker[object]): def all_mm_data(self) -> Optional[MultiModalDataDict]: - if self._items_by_modality: - return dict(self._items_by_modality) - - return None + if not self._items_by_modality: + return None + mm_inputs = {} + items_by_modality = dict(self._items_by_modality) + if "image" in items_by_modality and "image_embeds" in items_by_modality: + raise ValueError(\ + "Mixing raw image and embedding inputs is not allowed") + + if "image_embeds" in items_by_modality: + image_embeds_lst = items_by_modality["image_embeds"] + if len(image_embeds_lst) > 1: + raise ValueError(\ + "Only one message can have {'type': 'image_embeds'}") + mm_inputs["image"] = image_embeds_lst[0] + elif "image" in items_by_modality: + mm_inputs["image"] = items_by_modality["image"] # A list of images + elif "audio" in items_by_modality: + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + elif "video" in items_by_modality: + mm_inputs["video"] = items_by_modality["video"] # A list of videos + return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": return MultiModalContentParser(self) @@ -482,13 +511,31 @@ def create_parser(self) -> "BaseMultiModalContentParser": class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): async def all_mm_data(self) -> Optional[MultiModalDataDict]: - if self._items_by_modality: - return { + if not self._items_by_modality: + return None + mm_inputs = {} + items_by_modality = { modality: await asyncio.gather(*items) for modality, items in self._items_by_modality.items() } - return None + if "image" in items_by_modality and "image_embeds" in items_by_modality: + raise ValueError( + "Mixing raw image and embedding inputs is not allowed") + + if "image_embeds" in items_by_modality: + image_embeds_lst = items_by_modality["image_embeds"] + if len(image_embeds_lst) > 1: + raise ValueError( + "Only one message can have {'type': 'image_embeds'}") + mm_inputs["image"] = image_embeds_lst[0] + elif "image" in items_by_modality: + mm_inputs["image"] = items_by_modality["image"] # A list of images + elif "audio" in items_by_modality: + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + elif "video" in items_by_modality: + mm_inputs["video"] = items_by_modality["video"] # A list of videos + return mm_inputs def 
 
     def create_parser(self) -> "BaseMultiModalContentParser":
         return AsyncMultiModalContentParser(self)
 
@@ -513,6 +560,11 @@ def mm_placeholder_counts(self) -> dict[str, int]:
     def parse_image(self, image_url: str) -> None:
         raise NotImplementedError
 
+    @abstractmethod
+    def parse_image_embeds(self,
+                           image_embeds: Union[str, dict[str, str]]) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def parse_audio(self, audio_url: str) -> None:
         raise NotImplementedError
@@ -543,6 +595,21 @@ def parse_image(self, image_url: str) -> None:
         placeholder = self._tracker.add("image", image)
         self._add_placeholder(placeholder)
 
+    def parse_image_embeds(self,
+                           image_embeds: Union[str, dict[str, str]]) -> None:
+        if isinstance(image_embeds, dict):
+            embeds = {
+                k: self._connector.fetch_image_embedding(v)
+                for k, v in image_embeds.items()
+            }
+            placeholder = self._tracker.add("image_embeds", embeds)
+        elif isinstance(image_embeds, str):
+            embedding = self._connector.fetch_image_embedding(image_embeds)
+            placeholder = self._tracker.add("image_embeds", embedding)
+
+        self._add_placeholder(placeholder)
+
     def parse_audio(self, audio_url: str) -> None:
         audio = self._connector.fetch_audio(audio_url)
 
@@ -579,6 +646,25 @@ def parse_image(self, image_url: str) -> None:
         placeholder = self._tracker.add("image", image_coro)
         self._add_placeholder(placeholder)
 
+    def parse_image_embeds(self,
+                           image_embeds: Union[str, dict[str, str]]) -> None:
+        future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
+
+        if isinstance(image_embeds, dict):
+            embeds = {
+                k: self._connector.fetch_image_embedding(v)
+                for k, v in image_embeds.items()
+            }
+            future.set_result(embeds)
+        elif isinstance(image_embeds, str):
+            embedding = self._connector.fetch_image_embedding(image_embeds)
+            future.set_result(embedding)
+
+        placeholder = self._tracker.add("image_embeds", future)
+        self._add_placeholder(placeholder)
+
     def parse_audio(self, audio_url: str) -> None:
         audio_coro = self._connector.fetch_audio_async(audio_url)
 
@@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
 # No need to validate using Pydantic again
 _TextParser = partial(cast, ChatCompletionContentPartTextParam)
 _ImageParser = partial(cast, ChatCompletionContentPartImageParam)
+_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
 _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
 _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
 _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
@@ -700,6 +787,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
     lambda part: _TextParser(part).get("text", ""),
     "image_url":
     lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
+    "image_embeds":
+    lambda part: _ImageEmbedsParser(part).get("image_embeds", {}),
     "audio_url":
     lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
     "input_audio":
@@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(
 VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
+                                       "image_embeds",
                                        "audio_url", "input_audio",
                                        "video_url")
 
@@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
         str_content = cast(str, content)
         mm_parser.parse_image(str_content)
         return {'type': 'image'} if wrap_dicts else None
-
+    if part_type == "image_embeds":
+        content = cast(Union[str, dict[str, str]], content)
+        mm_parser.parse_image_embeds(content)
+        return {'type': 'image'} if wrap_dicts else None
     if part_type == "audio_url":
         str_content = cast(str, content)
         mm_parser.parse_audio(str_content)
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index 98ece8f806f1..ee990e870b82 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -134,3 +134,22 @@ def encode_base64(
         data = buffer.getvalue()
 
     return base64.b64encode(data).decode('utf-8')
+
+
+class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def load_bytes(self, data: bytes) -> torch.Tensor:
+        buffer = BytesIO(data)
+        return torch.load(buffer, weights_only=True)
+
+    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
+        return self.load_bytes(base64.b64decode(data))
+
+    def load_file(self, filepath: Path) -> torch.Tensor:
+        return torch.load(filepath, weights_only=True)
+
+    def encode_base64(self, media: torch.Tensor) -> str:
+        # Serialize with torch.save so the output round-trips through
+        # load_bytes/load_base64 above.
+        buffer = BytesIO()
+        torch.save(media, buffer)
+        return base64.b64encode(buffer.getvalue()).decode('utf-8')
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index 6e6c10b34a25..ad381e1d1d00 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -7,6 +7,7 @@
 import numpy as np
 import numpy.typing as npt
+import torch
 from PIL import Image
 
 import vllm.envs as envs
@@ -16,7 +17,7 @@
 from .audio import AudioMediaIO
 from .base import MediaIO
-from .image import ImageMediaIO
+from .image import ImageEmbeddingMediaIO, ImageMediaIO
 from .inputs import PlaceholderRange
 from .video import VideoMediaIO
 
@@ -245,6 +246,17 @@ async def fetch_video_async(
             fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
         )
 
+    def fetch_image_embedding(
+        self,
+        data: str,
+    ) -> torch.Tensor:
+        """
+        Load an image embedding from a base64-encoded string.
+        """
+        image_embedding_io = ImageEmbeddingMediaIO()
+
+        return image_embedding_io.load_base64("", data)
+
 
 global_media_connector = MediaConnector()
 """The global :class:`MediaConnector` instance used by vLLM."""
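As a sanity check on the new plumbing, the following minimal sketch (not part of the patch) encodes a tensor the same way the docs example does and decodes it through the new `MediaConnector.fetch_image_embedding` helper; the tensor shape is an arbitrary illustration:

```python
# Minimal round-trip sketch for the new embedding helpers (illustrative only).
import base64
import io

import torch

from vllm.multimodal.utils import MediaConnector

image_embeds = torch.rand(576, 4096)  # Arbitrary example shape

# Client side: serialize with torch.save and base64-encode, as in the docs example.
buffer = io.BytesIO()
torch.save(image_embeds, buffer)
payload = base64.b64encode(buffer.getvalue()).decode("utf-8")

# Server side: chat_utils hands the payload to the connector, which decodes it
# back into a tensor via ImageEmbeddingMediaIO.load_base64.
restored = MediaConnector().fetch_image_embedding(payload)
assert torch.equal(restored, image_embeds)
```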