58 changes: 25 additions & 33 deletions vllm/model_executor/models/deepseek_vl2.py
@@ -20,8 +20,7 @@
from vllm.model_executor.models.transformers import replace_linear_class
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalUUIDDict,
NestedTensors)
MultiModalKwargsItems, MultiModalUUIDDict)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -40,7 +39,7 @@
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix)

# The image token id may vary
@@ -50,15 +49,15 @@
class DeepseekVL2ImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- bnp: Batch size * number of images * number of patches
- p: Number of patches
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image
"""
type: Literal["pixel_values"]
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
data: Annotated[torch.Tensor,
TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})]
images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]


@@ -228,12 +227,8 @@ def _call_hf_processor(
tok_kwargs=tok_kwargs,
)

pixel_values = processed_outputs["pixel_values"]
# split pixel values into patches corresponding to each image
images_spatial_crop = processed_outputs["images_spatial_crop"]
patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]
pixel_values = pixel_values.split(patches_per_image)
processed_outputs["pixel_values"] = pixel_values
processed_outputs["num_patches"] = (
processed_outputs["images_spatial_crop"].prod(-1) + 1)

return processed_outputs

@@ -242,8 +237,11 @@ def _get_mm_fields_config(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_patches = hf_inputs.get("num_patches", torch.empty(0))

return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_patches),
images_spatial_crop=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
@@ -318,6 +316,7 @@ def _cached_apply_hf_processor(
info=DeepseekVL2ProcessingInfo,
dummy_inputs=DeepseekVL2DummyInputsBuilder)
class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True

hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
"language.": "language_model.",
@@ -460,37 +459,30 @@ def _parse_and_validate_image_input(

if pixel_values is not None:
expected_h = expected_w = self.vision_config.image_size
return DeepseekVL2ImagePixelInputs(type="pixel_values",
data=flatten_bn(pixel_values),
images_spatial_crop=flatten_bn(
images_spatial_crop,
concat=True),
resolve_bindings={
"h": expected_h,
"w": expected_w,
})
return DeepseekVL2ImagePixelInputs(
type="pixel_values",
data=pixel_values,
images_spatial_crop=images_spatial_crop,
resolve_bindings={
"h": expected_h,
"w": expected_w,
})

if image_embeds is not None:
return DeepseekVL2VImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
data=image_embeds,
)

raise AssertionError("This line should be unreachable.")

def _pixel_values_to_embedding(
self,
pixel_values: NestedTensors,
pixel_values: torch.Tensor,
images_spatial_crop: torch.Tensor,
) -> NestedTensors:
# Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
total_tiles = [x for x in pixel_values]

# [batch_all_tiles, 3, height, width]
total_tiles = torch.cat(total_tiles, dim=0)

) -> list[torch.Tensor]:
# [batch_all_tiles, vit_seq_len, c]
images_feature = self.vision.forward_features(total_tiles)
images_feature = self.vision.forward_features(pixel_values)

# [batch_all_tiles, hw, D]
images_embeds = self.projector(images_feature)
@@ -573,7 +565,7 @@ def _pixel_values_to_embedding(
return vision_embeddings

def _process_image_input(
self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor:
self, image_input: DeepseekVL2ImageInputs) -> list[torch.Tensor]:
if image_input["type"] == "image_embeds":
image_data = image_input["data"]
if is_list_of(image_data, torch.Tensor):
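Note on the deepseek_vl2.py hunks above: `_call_hf_processor` no longer splits `pixel_values` into per-image lists; it emits one flat patch tensor plus a `num_patches` field, and `MultiModalFieldConfig.flat_from_sizes("image", num_patches)` records how that flat "bnp" dimension maps back onto individual images. Below is a minimal sketch of that slicing semantics in plain PyTorch; the crop values and the 384x384 patch size are hypothetical, and this is not vLLM's internal implementation.

# Sketch only: how a flat "bnp" patch tensor is recovered per image from the
# per-image patch counts (images_spatial_crop.prod(-1) + 1) computed above.
import torch

# Two hypothetical images with 2x2 and 1x3 spatial crops.
images_spatial_crop = torch.tensor([[2, 2], [1, 3]])
num_patches = images_spatial_crop.prod(-1) + 1          # tensor([5, 4])

# Flat "bnp" tensor: the patches of all images stacked along dim 0.
pixel_values = torch.randn(int(num_patches.sum()), 3, 384, 384)

# flat_from_sizes-style recovery of per-image chunks from the flat dimension.
per_image = pixel_values.split(num_patches.tolist(), dim=0)
assert [t.shape[0] for t in per_image] == [5, 4]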
75 changes: 24 additions & 51 deletions vllm/model_executor/models/dots_ocr.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping
from typing import Literal, Optional, TypedDict, Union
from typing import Annotated, Literal, Optional, Union

import torch
import torch.nn as nn
@@ -42,34 +42,38 @@
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
DotsVisionConfig)
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .vision import run_dp_sharded_mrope_vision_model

IMAGE_TOKEN = "<|imgpad|>"


class DotsOCRImagePixelInputs(TypedDict):
type: Literal["pixel_values", "image_grid_thw"]
class DotsOCRImagePixelInputs(TensorSchema):
"""
Dimensions:
- np: Total number of patches across all images in the batch
- ni: Number of images
- cps: Number of channels * patch_size * patch_size
"""
type: Literal["pixel_values"]

pixel_values: torch.Tensor
image_grid_thw: torch.Tensor
pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


class DotsOCRImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds", "image_grid_thw"]
image_embeds: torch.Tensor
"""Supported types:
- List[`torch.Tensor`]: A list of tensors holding all images' features.
Each tensor holds an image's features.
- `torch.Tensor`: A tensor holding all images' features
(concatenation of all images' feature tensors).
Tensor shape: `(num_image_features, hidden_size)`
- `num_image_features` varies based on
the number and resolution of the images.
- `hidden_size` must match the hidden size of language model backbone.
class DotsOCRImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- nf: Number of image features
- hs: Hidden size
- ni: Number of images
"""
type: Literal["image_embeds"]

image_grid_thw: torch.Tensor
image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


DotsOCRImageInputs = Union[DotsOCRImagePixelInputs,
@@ -654,6 +658,8 @@ def forward(self, hidden_states: torch.Tensor,
)
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
merge_by_field_config = True

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
".attn.qkv_proj.": ".attn.qkv.",
@@ -709,22 +715,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
architectures=["Qwen2ForCausalLM"],
)

def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return torch.concat(list(mm_input))
else:
return torch.concat(mm_input)

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[DotsOCRImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@@ -735,28 +725,11 @@ def _parse_and_validate_image_input(
return None

if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")

if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")

return DotsOCRImagePixelInputs(type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw)

if image_embeds is not None:
image_embeds = self._validate_and_reshape_mm_tensor(
image_embeds, "image embeds")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")

if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return DotsOCRImageEmbeddingInputs(type="image_embeds",
image_embeds=image_embeds,
image_grid_thw=image_grid_thw)
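Note on the dots_ocr.py hunks above: the TypedDict input classes become TensorSchema subclasses whose fields carry Annotated TensorShape declarations, which is what lets the hand-written `_validate_and_reshape_mm_tensor` checks be deleted. A minimal construction sketch follows, assuming vLLM is installed and that TensorSchema checks the annotated shapes when the object is built; the shapes themselves are hypothetical.

import torch
from vllm.model_executor.models.dots_ocr import DotsOCRImagePixelInputs

# Hypothetical example: 2 images, 32 patches each (grid 1x4x8), 14x14 patches,
# 3 channels, so pixel_values has shape ("np", "cps") = (64, 3 * 14 * 14).
inputs = DotsOCRImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.randn(64, 3 * 14 * 14),
    image_grid_thw=torch.tensor([[1, 4, 8], [1, 4, 8]]),
)
# A tensor with the wrong rank, or a grid tensor whose last dim is not 3,
# should be rejected here rather than deep inside the forward pass.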
76 changes: 24 additions & 52 deletions vllm/model_executor/models/ernie45_vl.py
@@ -25,7 +25,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Any, Callable, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Callable, Literal, Optional, Union

import numpy as np
import torch
@@ -56,6 +56,7 @@
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -579,38 +580,38 @@ def load_weights(self, weights) -> set[str]:
# === Vision Inputs === #


class Ernie4_5_VLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
"""Shape:
`(num_patches, num_channels * patch_size * patch_size)`
class Ernie4_5_VLImagePixelInputs(TensorSchema):
"""

grid_thw: torch.Tensor
"""Shape: `(num_images, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
Dimensions:
- np: Total number of patches across all images in the batch
- ni: Number of images
- cps: Number of channels * patch_size * patch_size
"""
type: Literal["pixel_values"]

pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs


class Ernie4_5_VLVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
pixel_values_videos: torch.Tensor
"""Shape:
`(num_patches,
num_channels * temporal_patch_size * patch_size * patch_size)`
class Ernie4_5_VLVideoPixelInputs(TensorSchema):
"""

video_grid_thw: torch.Tensor
"""Shape: `(num_videos, 3)`

This should be in `(grid_t, grid_h, grid_w)` format.
Dimensions:
- np: Total number of patches across all videos in the batch
- ni: Number of videos
- cps: Number of channels * temporal_patch_size * patch_size *
patch_size
"""
type: Literal["pixel_values_videos"]
pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "cps")]
video_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


Ernie4_5_VLVideoInputs = Ernie4_5_VLImagePixelInputs
Ernie4_5_VLVideoInputs = Ernie4_5_VLVideoPixelInputs

# === Vision Processor === #

@@ -1213,6 +1214,7 @@ def get_dummy_mm_data(
dummy_inputs=Ernie4_5_VLDummyInputsBuilder)
class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):
merge_by_field_config = True

packed_modules_mapping = {
"qkv_proj": [
@@ -1325,22 +1327,6 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
def get_language_model(self) -> torch.nn.Module:
return self.language_model

def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return mm_input.reshape(-1, mm_input.shape[-1])
else:
return torch.concat(mm_input)

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
Expand All @@ -1350,15 +1336,6 @@ def _parse_and_validate_image_input(
return None

if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")

if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")

return Ernie4_5_VLImagePixelInputs(type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw)
@@ -1372,11 +1349,6 @@ def _parse_and_validate_video_input(
return None

if pixel_values_videos is not None:
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos, "video pixel values")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")

return Ernie4_5_VLVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
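Note on the common pattern across all three models: setting `merge_by_field_config = True` asks vLLM's multimodal input pipeline to hand the model kwargs that are already merged and flattened according to `_get_mm_fields_config`, so tensors arrive in the flat 2D layout that the deleted `_validate_and_reshape_mm_tensor` helpers used to produce by hand. For reference, a plain-PyTorch sketch of what those helpers did (shapes are hypothetical; this mirrors the ernie45_vl variant shown in the removed lines):

import torch

def legacy_reshape(mm_input):
    # Accept either a batched 3D tensor or a list of 2D tensors and return a
    # single flat 2D tensor, as the deleted helper did.
    if isinstance(mm_input, torch.Tensor):
        if mm_input.ndim == 2:
            return mm_input
        return mm_input.reshape(-1, mm_input.shape[-1])
    return torch.concat(mm_input)

batched = torch.randn(2, 16, 1176)      # 2 prompts, 16 patches, 1176 features
assert legacy_reshape(batched).shape == (32, 1176)
assert legacy_reshape([torch.randn(16, 1176), torch.randn(8, 1176)]).shape == (24, 1176)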