 import torch
 import torch.nn as nn
 from PIL import Image
-from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
+from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
+                          SiglipVisionConfig)
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
                    dummy_seq_data_for_clip, get_max_clip_image_tokens,
                    input_processor_for_clip)
 from .interfaces import SupportsMultiModal, SupportsPP
+from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
+                      dummy_seq_data_for_pixtral_hf,
+                      get_max_pixtral_hf_image_tokens,
+                      input_processor_for_pixtral_hf)
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
 
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, height, width)`
+
+    Note that `height` or `width` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """
 
 
 class LlavaImageEmbeddingInputs(TypedDict):
@@ -77,6 +87,8 @@ def get_max_llava_image_tokens(ctx: InputContext):
         num_image_tokens = get_max_clip_image_tokens(vision_config)
     elif isinstance(vision_config, SiglipVisionConfig):
         num_image_tokens = get_max_siglip_image_tokens(vision_config)
+    elif isinstance(vision_config, PixtralVisionConfig):
+        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
     else:
         msg = f"Unsupported vision config: {type(vision_config)}"
         raise NotImplementedError(msg)
@@ -120,6 +132,17 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
 
         mm_data = dummy_image_for_siglip(vision_config, num_images)
         return seq_data, mm_data
+    elif isinstance(vision_config, PixtralVisionConfig):
+        seq_data = dummy_seq_data_for_pixtral_hf(
+            vision_config,
+            seq_len,
+            num_images,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+        mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
+        return seq_data, mm_data
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -163,6 +186,15 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
+    elif isinstance(vision_config, PixtralVisionConfig):
+        # We ignore image_feature_size_override since we have non-uniform
+        # image sizes for Pixtral
+        return input_processor_for_pixtral_hf(
+            model_config,
+            vision_config,
+            inputs,
+            image_token_id=hf_config.image_token_index,
+        )
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -189,6 +221,9 @@ def _init_vision_tower(hf_config: LlavaConfig):
             vision_config,
             num_hidden_layers_override=num_hidden_layers,
         )
+    elif isinstance(vision_config, PixtralVisionConfig):
+        # TODO: allow layer override?
+        return PixtralHFVisionModel(vision_config)
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -210,6 +245,15 @@ def __init__(self,
         self.config = config
         self.multimodal_config = multimodal_config
 
+        # NOTE: These are special cases for Pixtral-12B in the HF-format
+        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
+        if (config.text_config.architectures is None
+                and config.text_config.model_type == "mistral"):
+            config.text_config.architectures = ["MistralForCausalLM"]
+        if (config.projector_hidden_act is None
+                and config.vision_config.hidden_act == "gelu"):
+            config.projector_hidden_act = "gelu"
+
         # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = _init_vision_tower(config)
         self.multi_modal_projector = LlavaMultiModalProjector(
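
Note on the config defaulting added above: the HF-format Pixtral-12B checkpoint ships a config whose text_config has no architectures entry and whose projector_hidden_act is unset, so both are filled in before the submodules are built. Below is a minimal standalone sketch of the same defaulting, assuming the mistral-community/pixtral-12b config layout referenced in the comment; it is illustrative only and not part of the commit.

# Sketch: default the fields vLLM relies on when they are missing from the
# HF config. Assumes a LlavaConfig loaded from mistral-community/pixtral-12b.
from transformers import LlavaConfig

config = LlavaConfig.from_pretrained("mistral-community/pixtral-12b")
if (config.text_config.architectures is None
        and config.text_config.model_type == "mistral"):
    # The checkpoint omits this, but the language model is resolved by name.
    config.text_config.architectures = ["MistralForCausalLM"]
if (config.projector_hidden_act is None
        and config.vision_config.hidden_act == "gelu"):
    # Match the projector activation to the vision tower's activation.
    config.projector_hidden_act = "gelu"
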
@@ -246,6 +290,7 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values is None and image_embeds is None:
@@ -256,6 +301,26 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Case for models like PixtralHF that have dynamic image sizes
+            # so we need to produce a list of tensors
+            if image_sizes is not None:
+                images = pixel_values
+                if isinstance(images, torch.Tensor):
+                    # if passed as batch take all images
+                    NN, N, B, C, W, H = images.shape
+                    images = images.reshape(NN * N * B, C, W, H)
+                    images = [images[i] for i in range(images.size(0))]
+                elif isinstance(images, list):
+                    # if passed as list flatten lists of tensors
+                    while isinstance(images, list) and len(images) == 1:
+                        images = images[0]
+
+                # TODO: Add validation based on image_sizes
+                return LlavaImagePixelInputs(
+                    type="pixel_values",
+                    data=images,
+                )
+
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
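
The reshape in the new image_sizes branch assumes the batched path delivers pixel values nested as (NN, N, B, C, W, H) before flattening them into the per-image list that PixtralHFVisionModel consumes. A small self-contained sketch with made-up shapes, illustrative only:

# Two 64x48 RGB images arriving as one batched tensor.
import torch

batched = torch.randn(1, 1, 2, 3, 64, 48)        # (NN, N, B, C, W, H)
NN, N, B, C, W, H = batched.shape
flat = batched.reshape(NN * N * B, C, W, H)      # (2, 3, 64, 48)
images = [flat[i] for i in range(flat.size(0))]  # list of (C, W, H) tensors
assert len(images) == 2 and images[0].shape == (3, 64, 48)
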
@@ -286,7 +351,8 @@ def _select_image_features(self, image_features: torch.Tensor, *,
 
     def _image_pixels_to_features(
         self,
-        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
+                            PixtralHFVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
 
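
As a rough end-to-end check of the new code path (not part of this commit), the HF-format checkpoint can be run through vLLM's usual offline entry point. The prompt template, file name, and sampling settings below are assumptions for illustration:

# Illustrative usage sketch; assumes the mistral-community/pixtral-12b
# checkpoint fits on the available GPU.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="mistral-community/pixtral-12b", max_model_len=8192)
image = Image.open("example.jpg")  # hypothetical local image

outputs = llm.generate(
    {
        "prompt": "[INST][IMG]\nDescribe the image.[/INST]",  # assumed template
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)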