2 changes: 1 addition & 1 deletion examples/offline_inference/vision_language_multi_image.py
@@ -548,7 +548,7 @@ def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
max_model_len=32768,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
32 changes: 9 additions & 23 deletions vllm/model_executor/models/idefics3.py
@@ -53,7 +53,7 @@
# yapf: enable
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix


class Idefics3ImagePixelInputs(TensorSchema):
@@ -67,7 +67,7 @@ class Idefics3ImagePixelInputs(TensorSchema):
"""
type: Literal["pixel_values"]
pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
pixel_attention_mask: torch.Tensor
pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bnp", "h", "w")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]


@@ -569,6 +569,8 @@ def forward(
dummy_inputs=Idefics3DummyInputsBuilder)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA):
merge_by_field_config = True

packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -621,37 +623,21 @@ def _parse_and_validate_image_input(
return None

if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")

return Idefics3ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
data=image_embeds,
)

if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

pixel_attention_mask = kwargs.pop("pixel_attention_mask")
if not isinstance(pixel_attention_mask, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel_attention_mask. "
f"Got type: {type(pixel_attention_mask)}")

num_patches = kwargs.pop("num_patches")
if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")

expected_h = expected_w = self.config.vision_config.image_size

return Idefics3ImagePixelInputs(
type="pixel_values",
pixel_values=flatten_bn(pixel_values, concat=True),
pixel_attention_mask=flatten_bn(pixel_attention_mask,
concat=True),
num_patches=flatten_bn(num_patches, concat=True),
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
num_patches=num_patches,
resolve_bindings={
"h": expected_h,
"w": expected_w
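Not part of the diff above — a minimal sketch of the TensorSchema pattern the idefics3.py changes lean on, assuming shape symbols in TensorShape are checked when the schema is constructed and that resolve_bindings pins named dimensions, as the hunk suggests. The class name and tensor sizes here are illustrative only.

from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExamplePixelInputs(TensorSchema):
    """
    Dimensions:
        - bnp: Batch size * Number of patches
        - h: Height of each patch
        - w: Width of each patch
    """
    type: Literal["pixel_values"]
    pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
    pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bnp", "h", "w")]


# "h" and "w" are pinned to the configured image size via resolve_bindings,
# mirroring _parse_and_validate_image_input above; a shape mismatch fails
# schema validation instead of a hand-written isinstance check.
inputs = ExamplePixelInputs(
    type="pixel_values",
    pixel_values=torch.zeros(4, 3, 364, 364),
    pixel_attention_mask=torch.ones(4, 364, 364, dtype=torch.bool),
    resolve_bindings={"h": 364, "w": 364},
)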
59 changes: 7 additions & 52 deletions vllm/model_executor/models/keye.py
@@ -30,7 +30,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, VideoItem)
@@ -42,7 +42,6 @@
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -100,8 +99,7 @@ def smart_resize(
class KeyeImagePixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- np: Number of patches
- bnp: Batch size * Number of patches
- c: Number of channels
- ps: Patch size
- ni: Number of images
@@ -110,7 +108,7 @@ class KeyeImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
pixel_values: Annotated[
torch.Tensor,
TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


@@ -134,8 +132,7 @@ class KeyeImageEmbeddingInputs(TensorSchema):
class KeyeVideoPixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- np: Number of patches
- bnp: Batch size * Number of patches
- c: Number of channels
- ps: Patch size
- ni: Number of images
@@ -144,7 +141,7 @@ class KeyeVideoPixelInputs(TensorSchema):
type: Literal["pixel_values_videos"]
pixel_values_videos: Annotated[
torch.Tensor,
TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]


@@ -1258,6 +1255,8 @@ def _get_mm_fields_config(


class BaseKeyeModule(nn.Module):
merge_by_field_config = True

packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -1524,28 +1523,6 @@ def _build_projector(self,
prefix: str = "") -> nn.Module:
return Projector(text_config, vision_config, quant_config, prefix)

def _validate_and_reshape_mm_tensor(
self, mm_input: NestedTensors,
name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim == 5:
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return mm_input.reshape(-1, mm_input.shape[-1])
elif is_list_of(mm_input, torch.Tensor):
if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
for p in mm_input):
return mm_input
return torch.concat(mm_input)

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[KeyeImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@@ -1556,23 +1533,13 @@ def _parse_and_validate_image_input(
return None

if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")

return KeyeImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)

if image_embeds is not None:
image_embeds = self._validate_and_reshape_mm_tensor(
image_embeds, "image embeds")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")

return KeyeImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
@@ -1589,25 +1556,13 @@ def _parse_and_validate_video_input(
return None

if pixel_values_videos is not None:
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos,
"video pixel values",
)
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")

return KeyeVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
)

if video_embeds is not None:
video_embeds = self._validate_and_reshape_mm_tensor(
video_embeds, "video embeds")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")

return KeyeVideoEmbeddingInputs(
type="video_embeds",
video_embeds=video_embeds,
56 changes: 5 additions & 51 deletions vllm/model_executor/models/keye_vl1_5.py
@@ -18,7 +18,7 @@
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig,
MultiModalKwargsItems, VideoItem)
@@ -100,8 +100,7 @@ def get_num_patches(grid_thw: torch.Tensor,
class KeyeVL1_5ImagePixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- np: Number of patches
- bnp: Batch size * Number of patches
- c: Number of channels
- ps: Patch size
- ni: Number of images
@@ -111,7 +110,7 @@ class KeyeVL1_5ImagePixelInputs(TensorSchema):

pixel_values: Annotated[
torch.Tensor,
TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]

image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

@@ -137,8 +136,7 @@ class KeyeVL1_5ImageEmbeddingInputs(TensorSchema):
class KeyeVL1_5VideoPixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- np: Number of patches
- bnp: Batch size * Number of patches
- c: Number of channels
- ps: Patch size
- ni: Number of images
@@ -147,7 +145,7 @@ class KeyeVL1_5VideoPixelInputs(TensorSchema):
type: Literal["pixel_values_videos"]
pixel_values_videos: Annotated[
torch.Tensor,
TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]

num_frames: torch.Tensor
@@ -483,24 +481,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.merge_size = config.vision_config.spatial_merge_size
super().__init__(vllm_config=vllm_config, prefix=prefix)

def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors,
expected_dim: int, name: str):
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == expected_dim:
return mm_input
elif mm_input.ndim == expected_dim + 1:
return mm_input.reshape(-1, *mm_input.shape[2:])
else:
raise ValueError(
f"{name} should be {expected_dim}D or "
f"batched {expected_dim}D tensor."
f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})")
else:
return torch.concat(mm_input)

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@@ -511,23 +491,13 @@ def _parse_and_validate_image_input(
return None

if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, expected_dim=4, name="image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, expected_dim=2, name="image grid_thw")

return KeyeVL1_5ImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)

if image_embeds is not None:
image_embeds = self._validate_and_reshape_mm_tensor(
image_embeds, expected_dim=2, name="image embeds")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, expected_dim=2, name="image grid_thw")

return KeyeVL1_5ImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
@@ -545,29 +515,13 @@ def _parse_and_validate_video_input(
return None

if pixel_values_videos is not None:
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos,
expected_dim=4,
name="video pixel values",
)
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, expected_dim=2, name="video grid_thw")

num_frames = self._validate_and_reshape_mm_tensor(
num_frames, expected_dim=1, name="video num frames")

return KeyeVL1_5VideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
num_frames=num_frames)

if video_embeds is not None:
video_embeds = self._validate_and_reshape_mm_tensor(
video_embeds, expected_dim=2, name="video embeds")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, expected_dim=2, name="video grid_thw")

return KeyeVL1_5VideoEmbeddingInputs(type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw,
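Also not from this PR — a rough sketch of what the simplified parsing relies on, under the assumption that merge_by_field_config = True hands the model pre-merged tensors keyed by field, so a parse method can forward kwargs straight into a schema whose dynamic "bnp" axis absorbs the variable patch count. Everything here is illustrative; only the identifiers that already appear in the hunks above are taken from vLLM.

from typing import Annotated, Literal, Optional

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExampleImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bnp: Batch size * Number of patches (variable per request)
        - ps: Patch size
        - ni: Number of images
    """
    type: Literal["pixel_values"]
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


def parse_image_input(**kwargs: object) -> Optional[ExampleImagePixelInputs]:
    # With pre-merged tensors there is nothing left to isinstance-check or
    # reshape; the TensorShape annotations validate the shapes instead.
    pixel_values = kwargs.pop("pixel_values", None)
    image_grid_thw = kwargs.pop("image_grid_thw", None)
    if pixel_values is None:
        return None
    return ExampleImagePixelInputs(
        type="pixel_values",
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
    )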