diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index 0f87fb34bf32..3e3b4e59f833 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -20,8 +20,7 @@
 from vllm.model_executor.models.transformers import replace_linear_class
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, MultiModalUUIDDict,
-                                    NestedTensors)
+                                    MultiModalKwargsItems, MultiModalUUIDDict)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -40,7 +39,7 @@
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP)
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
 
 # The image token id may be various
@@ -50,15 +49,15 @@ class DeepseekVL2ImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - bn: Batch size * number of images
+        - bnp: Batch size * number of images * number of patches
         - p: Number of patches
         - c: Number of channels (3)
         - h: Height of each image
         - w: Width of each image
     """
 
     type: Literal["pixel_values"]
 
-    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
-                    TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
+    data: Annotated[torch.Tensor,
+                    TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})]
 
     images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
@@ -228,12 +227,8 @@ def _call_hf_processor(
             tok_kwargs=tok_kwargs,
         )
 
-        pixel_values = processed_outputs["pixel_values"]
-        # split pixel values into patches corresponding to each image
-        images_spatial_crop = processed_outputs["images_spatial_crop"]
-        patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]
-        pixel_values = pixel_values.split(patches_per_image)
-        processed_outputs["pixel_values"] = pixel_values
+        processed_outputs["num_patches"] = (
+            processed_outputs["images_spatial_crop"].prod(-1) + 1)
 
         return processed_outputs
 
@@ -242,8 +237,11 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))
+
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_patches),
             images_spatial_crop=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
         )
@@ -318,6 +316,7 @@ def _cached_apply_hf_processor(
     info=DeepseekVL2ProcessingInfo,
     dummy_inputs=DeepseekVL2DummyInputsBuilder)
 class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
 
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
         "language.": "language_model.",
@@ -460,37 +459,30 @@ def _parse_and_validate_image_input(
         if pixel_values is not None:
             expected_h = expected_w = self.vision_config.image_size
-            return DeepseekVL2ImagePixelInputs(type="pixel_values",
-                                               data=flatten_bn(pixel_values),
-                                               images_spatial_crop=flatten_bn(
-                                                   images_spatial_crop,
-                                                   concat=True),
-                                               resolve_bindings={
-                                                   "h": expected_h,
-                                                   "w": expected_w,
-                                               })
+            return DeepseekVL2ImagePixelInputs(
+                type="pixel_values",
+                data=pixel_values,
+                images_spatial_crop=images_spatial_crop,
+                resolve_bindings={
+                    "h": expected_h,
+                    "w": expected_w,
+                })
 
         if image_embeds is not None:
             return DeepseekVL2VImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )
 
         raise AssertionError("This line should be unreachable.")
 
     def _pixel_values_to_embedding(
         self,
-        pixel_values: NestedTensors,
+        pixel_values: torch.Tensor,
         images_spatial_crop: torch.Tensor,
-    ) -> NestedTensors:
-        # Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
-        total_tiles = [x for x in pixel_values]
-
-        # [batch_all_tiles, 3, height, width]
-        total_tiles = torch.cat(total_tiles, dim=0)
-
+    ) -> list[torch.Tensor]:
         # [batch_all_tiles, vit_seq_len, c]
-        images_feature = self.vision.forward_features(total_tiles)
+        images_feature = self.vision.forward_features(pixel_values)
 
         # [batch_all_tiles, hw, D]
         images_embeds = self.projector(images_feature)
 
@@ -573,7 +565,7 @@ def _pixel_values_to_embedding(
         return vision_embeddings
 
     def _process_image_input(
-            self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor:
+            self, image_input: DeepseekVL2ImageInputs) -> list[torch.Tensor]:
         if image_input["type"] == "image_embeds":
             image_data = image_input["data"]
             if is_list_of(image_data, torch.Tensor):
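Note on the deepseek_vl2.py hunks above: the processor now leaves `pixel_values` flattened over batch * images * patches ("bnp") and records a per-image `num_patches` (`images_spatial_crop.prod(-1) + 1`), which `MultiModalFieldConfig.flat_from_sizes("image", num_patches)` uses to attribute slices of the flat tensor to each image. Below is a minimal sketch of that bookkeeping, not code from the patch; the crop grid and image size are invented.

```python
# Illustration only; crop grid and image size are made up.
import torch

# One crop-grid pair per image, as in images_spatial_crop.
images_spatial_crop = torch.tensor([[2, 2], [1, 3]])

# Mirrors the new _call_hf_processor logic; the +1 presumably accounts for
# the extra global view, matching the removed patches_per_image computation.
num_patches = images_spatial_crop.prod(-1) + 1      # tensor([5, 4])

# pixel_values is now a single tensor flattened over "bnp".
h = w = 32  # placeholder; the real value is vision_config.image_size
pixel_values = torch.randn(int(num_patches.sum()), 3, h, w)

# flat_from_sizes("image", num_patches) records these sizes so the flat
# tensor can be sliced back per image, conceptually:
per_image = pixel_values.split(num_patches.tolist(), dim=0)
assert [p.shape[0] for p in per_image] == num_patches.tolist()
```

With `merge_by_field_config = True`, the model already receives `pixel_values` in this flattened, merged form, which is why the `flatten_bn` calls could be dropped.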
diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index 4845f19bcbc4..e68777aab6bf 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -42,34 +42,38 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
                                                      DotsVisionConfig)
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .vision import run_dp_sharded_mrope_vision_model
 
 IMAGE_TOKEN = "<|imgpad|>"
 
 
-class DotsOCRImagePixelInputs(TypedDict):
-    type: Literal["pixel_values", "image_grid_thw"]
+class DotsOCRImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+          the batch
+        - ni: Number of images
+        - cps: Number of channels * patch_size * patch_size
+    """
+    type: Literal["pixel_values"]
 
-    pixel_values: torch.Tensor
-    image_grid_thw: torch.Tensor
+    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
-class DotsOCRImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds", "image_grid_thw"]
-    image_embeds: torch.Tensor
-    """Supported types:
-    - List[`torch.Tensor`]: A list of tensors holding all images' features.
-        Each tensor holds an image's features.
-    - `torch.Tensor`: A tensor holding all images' features
-        (concatenation of all images' feature tensors).
-        Tensor shape: `(num_image_features, hidden_size)`
-        - `num_image_features` varies based on
-            the number and resolution of the images.
-        - `hidden_size` must match the hidden size of language model backbone.
+class DotsOCRImageEmbeddingInputs(TensorSchema):
     """
+    Dimensions:
+        - nf: Number of image features
+        - hs: Hidden size
+        - ni: Number of images
+    """
+    type: Literal["image_embeds"]
 
-    image_grid_thw: torch.Tensor
+    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
 DotsOCRImageInputs = Union[DotsOCRImagePixelInputs,
@@ -654,6 +658,8 @@ def forward(self, hidden_states: torch.Tensor,
 )
 class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
                          SupportsLoRA):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             ".attn.qkv_proj.": ".attn.qkv.",
@@ -709,22 +715,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             architectures=["Qwen2ForCausalLM"],
         )
 
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return torch.concat(list(mm_input))
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[DotsOCRImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -735,28 +725,11 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             return DotsOCRImagePixelInputs(type="pixel_values",
                                            pixel_values=pixel_values,
                                            image_grid_thw=image_grid_thw)
 
         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
-            if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
             return DotsOCRImageEmbeddingInputs(type="image_embeds",
                                                image_embeds=image_embeds,
                                                image_grid_thw=image_grid_thw)
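Both dots_ocr.py above and ernie45_vl.py below swap hand-rolled TypedDict inputs plus a `_validate_and_reshape_mm_tensor` helper for `TensorSchema` classes whose `Annotated[..., TensorShape(...)]` fields declare the expected shapes. A minimal sketch of that pattern follows; the `ExamplePixelInputs` class and the toy tensors are invented for illustration and are not code from the patch.

```python
# Hedged sketch of the TensorSchema pattern; field names mirror
# DotsOCRImagePixelInputs, values are invented.
from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExamplePixelInputs(TensorSchema):
    """
    Dimensions:
        - np: Total number of patches across all images
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size
    """
    type: Literal["pixel_values"]

    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


# The schemas are constructed with keyword arguments, as the diff does for
# DotsOCRImagePixelInputs; the declared shapes are checked against the
# tensors, which appears to be what lets the manual isinstance/ndim checks
# go away.
inputs = ExamplePixelInputs(
    type="pixel_values",
    pixel_values=torch.randn(12, 1176),   # 12 patches, 3 * 14 * 14 features
    image_grid_thw=torch.tensor([[1, 2, 6]]),
)
```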
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index a73ec4f88ffe..c62658fa4c21 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -25,7 +25,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Any, Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -56,6 +56,7 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -579,38 +580,38 @@ def load_weights(self, weights) -> set[str]:
 
 # === Vision Inputs === #
 
 
-class Ernie4_5_VLImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """Shape:
-    `(num_patches, num_channels * patch_size * patch_size)`
+class Ernie4_5_VLImagePixelInputs(TensorSchema):
     """
-
-    grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+          the batch
+        - ni: Number of images
+        - cps: Number of channels * patch_size * patch_size
     """
+    type: Literal["pixel_values"]
+
+    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
 Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs
 
 
-class Ernie4_5_VLVideoPixelInputs(TypedDict):
-    type: Literal["pixel_values_videos"]
-    pixel_values_videos: torch.Tensor
-    """Shape:
-    `(num_patches,
-     num_channels * temporal_patch_size * patch_size * patch_size)`
+class Ernie4_5_VLVideoPixelInputs(TensorSchema):
     """
-
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-
-    This should be in `(grid_t, grid_h, grid_w)` format.
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+          the batch
+        - ni: Number of images
+        - cps: Number of channels * temporal_patch_size * patch_size *
+          patch_size
     """
+    type: Literal["pixel_values_videos"]
+    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    video_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
-Ernie4_5_VLVideoInputs = Ernie4_5_VLImagePixelInputs
+Ernie4_5_VLVideoInputs = Ernie4_5_VLVideoPixelInputs
 
 # === Vision Processor === #
@@ -1213,6 +1214,7 @@ def get_dummy_mm_data(
                         dummy_inputs=Ernie4_5_VLDummyInputsBuilder)
 class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsLoRA, SupportsPP):
+    merge_by_field_config = True
 
     packed_modules_mapping = {
         "qkv_proj": [
@@ -1325,22 +1327,6 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -1350,15 +1336,6 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             return Ernie4_5_VLImagePixelInputs(type="pixel_values",
                                                pixel_values=pixel_values,
                                                image_grid_thw=image_grid_thw)
@@ -1372,11 +1349,6 @@ def _parse_and_validate_video_input(
             return None
 
         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return Ernie4_5_VLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 9e491c0b50d2..2ab2cf9b17b3 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -59,17 +59,14 @@ class FuyuImagePatchInputs(TensorSchema):
 
     type: Literal["image_patches"] = "image_patches"
 
-    flat_data: Annotated[
-        torch.Tensor,
-        TensorShape("bnp", "fn"),
-    ]
+    image_patches_flat: Annotated[torch.Tensor, TensorShape("bnp", "fn")]
 
     patches_per_image: Annotated[list[int], TensorShape("bn")]
     """
     The number of total patches for each image in the batch.
 
     This is used to split the embeddings which has the first two dimensions
-    flattened just like `flat_data`.
+    flattened just like `image_patches_flat`.
     """
 
 
@@ -174,28 +171,10 @@ def _call_hf_processor(
             tok_kwargs=tok_kwargs,
         )
 
-        image_patches = processed_outputs.get("image_patches")
-        if image_patches is not None:
-            images = mm_data["images"]
-            assert isinstance(images, list)
-
-            # Original output: (1, num_images, Pn, Px * Py * C)
-            # New output: (num_images, Pn, Px * Py * C)
-            # image_patches is a list with shape:
-            # (1, num_images, Pn, Px * Py * C)
-            # before Transformers 4.53
-            if isinstance(image_patches, list):
-                assert len(image_patches) == 1
-                assert (isinstance(image_patches[0], torch.Tensor)
-                        and len(image_patches[0]) == len(images))
-                processed_outputs["image_patches"] = image_patches[0]
-            # image_patches is a tensor with shape:
-            # (num_images, Pn, Px * Py * C)
-            # after Transformers 4.53
-            elif isinstance(image_patches, torch.Tensor):
-                assert len(image_patches) == len(images)
-            else:
-                raise AssertionError("This line should be unreachable.")
+        image_patches = processed_outputs["image_patches"]
+        processed_outputs["image_patches"] = flatten_bn(image_patches)
+        processed_outputs["patches_per_image"] = torch.tensor(
+            [len(p) for p in image_patches])
 
         return processed_outputs
 
@@ -218,7 +197,13 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(image_patches=MultiModalFieldConfig.batched("image"))
+        patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
+
+        return dict(
+            image_patches=MultiModalFieldConfig.flat_from_sizes(
+                "image", patches_per_image),
+            patches_per_image=MultiModalFieldConfig.batched("image"),
+        )
 
     def _get_prompt_updates(
         self,
@@ -263,6 +248,7 @@ def get_replacement_fuyu(item_idx: int):
                                         info=FuyuProcessingInfo,
                                         dummy_inputs=FuyuDummyInputsBuilder)
 class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
 
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -306,29 +292,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
         image_patches = kwargs.pop("image_patches", None)
-        if image_patches is not None:
-            image_patches_flat = flatten_bn(image_patches)
-            flat_data = flatten_bn(image_patches_flat, concat=True)
-
-            return FuyuImagePatchInputs(
-                type="image_patches",
-                flat_data=flat_data,
-                patches_per_image=[x.size(0) for x in image_patches_flat],
-                resolve_bindings={"fn": self.image_feature_size},
-            )
+        patches_per_image = kwargs.pop("patches_per_image", None)
 
-        return None
+        if image_patches is None:
+            return None
+
+        return FuyuImagePatchInputs(
+            type="image_patches",
+            image_patches_flat=image_patches,
+            patches_per_image=patches_per_image,
+            resolve_bindings={"fn": self.image_feature_size},
+        )
 
     def _process_image_input(
             self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
-        image_patches_flat = image_input["flat_data"]
+        image_patches_flat = image_input["image_patches_flat"]
         patches_per_image = image_input["patches_per_image"]
 
         assert self.vision_embed_tokens is not None
         vision_embeddings_flat, _ = self.vision_embed_tokens(
             image_patches_flat)
-        return vision_embeddings_flat.split(patches_per_image, dim=0)
+        return vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)
 
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model