 
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
-from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
+from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
                     Union, cast)
 
 import torch
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .vision import get_vision_encoder_info
 
 
-class LlavaImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
+class LlavaImagePixelInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images, num_channels, height, width)`
-
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+
     Note that `height` or `width` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
 
 
-class PixtralHFImagePixelInputs(TypedDict):
-    type: Literal["pixel_values_pixtral"]
-    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+class PixtralHFImagePixelInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images, num_channels, height, width)`
-
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels
+        - h: Height
+        - w: Width
+
     Note that `height` or `width` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """
+    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
+    pixel_values: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                            TensorShape("bn", "c", "h", "w")]
 
 
-class LlavaImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
+class LlavaImageEmbeddingInputs(TensorSchema):
     """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ifs: Image feature size
+        - hs: Hidden size (must match the language model backbone)
+    """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
 
 
 LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
                          LlavaImageEmbeddingInputs]
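
For context beyond the diff: a minimal sketch of how these schema classes would be constructed, assuming `TensorSchema` validates each `Annotated[..., TensorShape(...)]` field at construction time, that literal dims (the `3` in `TensorShape("bn", 3, "h", "w")`) must match exactly while named dims are inferred from the value, and that a `list[torch.Tensor]` input is checked element-wise. All tensor sizes below are invented for illustration.

    import torch

    # The literal channel dim must be 3; "bn", "h" and "w" are free here,
    # so any tensor shaped (bn, 3, h, w) should pass. The `type` field now
    # has a default, so it can be omitted.
    imgs = LlavaImagePixelInputs(pixel_values=torch.rand(4, 3, 336, 336))

    # A mismatched channel count should be rejected by the schema:
    # LlavaImagePixelInputs(pixel_values=torch.rand(4, 1, 336, 336))

    # Per the docstrings, Pixtral images may differ in size per image,
    # so the field also accepts one tensor per image:
    pix = PixtralHFImagePixelInputs(
        pixel_values=[torch.rand(3, 512, 768), torch.rand(3, 256, 384)])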
@@ -547,19 +559,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-        actual_dims = tuple(data.shape[1:])
-
-        if actual_dims != expected_dims:
-            expected_expr = ("batch_size", *map(str, expected_dims))
-            raise ValueError(
-                f"The expected shape of pixel values is {expected_expr}. "
-                f"You supplied {tuple(data.shape)}.")
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -579,10 +578,14 @@ def _parse_and_validate_image_input(
                 pixel_values=flatten_bn(pixel_values),
             )
 
+        expected_h = expected_w = self.config.vision_config.image_size
         return LlavaImagePixelInputs(
            type="pixel_values",
-            pixel_values=self._validate_pixel_values(
-                flatten_bn(pixel_values, concat=True)),
+            pixel_values=flatten_bn(pixel_values, concat=True),
+            resolve_bindings={
+                "h": expected_h,
+                "w": expected_w
+            },
         )
 
        if image_embeds is not None:
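
For context beyond the diff: the deleted `_validate_pixel_values` compared `data.shape[1:]` against `(3, image_size, image_size)` by hand. Passing `resolve_bindings` moves that check into the schema by pinning the named `h` and `w` dims before validation. A hedged sketch of the resulting contract, with a made-up 336 standing in for `vision_config.image_size`:

    import torch

    # With "h" and "w" bound to 336, a (bn, 3, 336, 336) tensor validates;
    # any other spatial size fails schema validation, which is roughly the
    # contract the removed method enforced via its manual tuple comparison.
    inputs = LlavaImagePixelInputs(
        type="pixel_values",
        pixel_values=torch.rand(2, 3, 336, 336),
        resolve_bindings={"h": 336, "w": 336},
    )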