@@ -20,8 +20,7 @@
 from vllm.model_executor.models.transformers import replace_linear_class
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, MultiModalUUIDDict,
-                                    NestedTensors)
+                                    MultiModalKwargsItems, MultiModalUUIDDict)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -40,7 +39,7 @@
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
 
 # The image token id may be various
@@ -50,15 +49,15 @@
 class DeepseekVL2ImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - bn: Batch size * number of images
+        - bnp: Batch size * number of images * number of patches
         - p: Number of patches
        - c: Number of channels (3)
         - h: Height of each image
         - w: Width of each image
     """
     type: Literal["pixel_values"]
-    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
-                    TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
+    data: Annotated[torch.Tensor,
+                    TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})]
     images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
 
 
@@ -228,12 +227,8 @@ def _call_hf_processor(
             tok_kwargs=tok_kwargs,
         )
 
-        pixel_values = processed_outputs["pixel_values"]
-        # split pixel values into patches corresponding to each image
-        images_spatial_crop = processed_outputs["images_spatial_crop"]
-        patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]
-        pixel_values = pixel_values.split(patches_per_image)
-        processed_outputs["pixel_values"] = pixel_values
+        processed_outputs["num_patches"] = (
+            processed_outputs["images_spatial_crop"].prod(-1) + 1)
 
         return processed_outputs
 
@@ -242,8 +237,11 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))
+
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_patches),
             images_spatial_crop=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
         )
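The two hunks above replace the manual per-image splitting in `_call_hf_processor` with a `num_patches` field: `pixel_values` now stays flattened over all images, and `num_patches = images_spatial_crop.prod(-1) + 1` (the local crops plus one global view per image) is what `MultiModalFieldConfig.flat_from_sizes` uses to slice that flat leading dimension back into per-image groups. Below is a minimal sketch of that arithmetic in plain PyTorch; the helper name, shapes, and the 384x384 patch size are illustrative assumptions, not vLLM's API:

```python
import torch


def split_flat_pixel_values(
        pixel_values: torch.Tensor,  # [sum(num_patches), 3, h, w], all images flattened
        images_spatial_crop: torch.Tensor,  # [num_images, 2], tiling grid per image
) -> list[torch.Tensor]:
    """Illustrative only: recover per-image patch groups from the flat tensor."""
    # Each image contributes (tiles_x * tiles_y) local crops plus one global view,
    # which is exactly the "num_patches" the processor records above.
    num_patches = images_spatial_crop.prod(-1) + 1
    assert int(num_patches.sum()) == pixel_values.shape[0]
    # Slice the flat leading dimension back into one chunk per image.
    return list(pixel_values.split(num_patches.tolist(), dim=0))


crops = torch.tensor([[2, 2], [1, 3]])  # two images with different tilings
flat = torch.randn(int((crops.prod(-1) + 1).sum()), 3, 384, 384)
print([g.shape[0] for g in split_flat_pixel_values(flat, crops)])  # -> [5, 4]
```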
@@ -318,6 +316,7 @@ def _cached_apply_hf_processor(
                                         info=DeepseekVL2ProcessingInfo,
                                         dummy_inputs=DeepseekVL2DummyInputsBuilder)
 class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
 
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
         "language.": "language_model.",
@@ -460,37 +459,30 @@ def _parse_and_validate_image_input(
 
         if pixel_values is not None:
             expected_h = expected_w = self.vision_config.image_size
-            return DeepseekVL2ImagePixelInputs(type="pixel_values",
-                                               data=flatten_bn(pixel_values),
-                                               images_spatial_crop=flatten_bn(
-                                                   images_spatial_crop,
-                                                   concat=True),
-                                               resolve_bindings={
-                                                   "h": expected_h,
-                                                   "w": expected_w,
-                                               })
+            return DeepseekVL2ImagePixelInputs(
+                type="pixel_values",
+                data=pixel_values,
+                images_spatial_crop=images_spatial_crop,
+                resolve_bindings={
+                    "h": expected_h,
+                    "w": expected_w,
+                })
 
         if image_embeds is not None:
             return DeepseekVL2VImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )
 
         raise AssertionError("This line should be unreachable.")
 
     def _pixel_values_to_embedding(
         self,
-        pixel_values: NestedTensors,
+        pixel_values: torch.Tensor,
         images_spatial_crop: torch.Tensor,
-    ) -> NestedTensors:
-        # Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
-        total_tiles = [x for x in pixel_values]
-
-        # [batch_all_tiles, 3, height, width]
-        total_tiles = torch.cat(total_tiles, dim=0)
-
+    ) -> list[torch.Tensor]:
         # [batch_all_tiles, vit_seq_len, c]
-        images_feature = self.vision.forward_features(total_tiles)
+        images_feature = self.vision.forward_features(pixel_values)
 
         # [batch_all_tiles, hw, D]
         images_embeds = self.projector(images_feature)
@@ -573,7 +565,7 @@ def _pixel_values_to_embedding(
         return vision_embeddings
 
     def _process_image_input(
-            self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor:
+            self, image_input: DeepseekVL2ImageInputs) -> list[torch.Tensor]:
         if image_input["type"] == "image_embeds":
             image_data = image_input["data"]
             if is_list_of(image_data, torch.Tensor):