3939 PromptUpdate , PromptUpdateDetails )
4040from vllm .multimodal .profiling import BaseDummyInputsBuilder , ProcessorInputs
4141from vllm .sequence import IntermediateTensors
42- from vllm .utils import flatten_2d_lists
4342
4443from .interfaces import MultiModalEmbeddings , SupportsMultiModal , SupportsPP
4544from .utils import (AutoWeightsLoader , flatten_bn , maybe_prefix ,
@@ -66,10 +65,13 @@ class FuyuImagePatchInputs(TypedDict):
6665 This is used to split the embeddings which has the first two dimensions
6766 flattened just like `flat_data`.
6867 """
68+
6969 embed_is_patch : Union [torch .Tensor , list [torch .Tensor ]]
7070 """
7171 A boolean mask indicating which image embeddings correspond
7272 to patch tokens.
73+
74+ Shape: `(batch_size * num_images, num_embeds)`
7375 """
7476
7577
@@ -322,16 +324,18 @@ def _validate_shape(d: torch.Tensor):
322324 def _parse_and_validate_image_input (
323325 self , ** kwargs : object ) -> Optional [FuyuImagePatchInputs ]:
324326 image_patches = kwargs .pop ("image_patches" , None )
325- embed_is_patch = kwargs .pop ("embed_is_patch" , None )
326327 if image_patches is not None :
327328 if not isinstance (image_patches , (torch .Tensor , list )):
328329 raise ValueError ("Incorrect type of image patches. "
329330 f"Got type: { type (image_patches )} " )
330331
332+ embed_is_patch = kwargs .pop ("embed_is_patch" )
331333 if not isinstance (embed_is_patch , (torch .Tensor , list )):
332334 raise ValueError ("Incorrect type of embed_is_patch. "
333335 f"Got type: { type (embed_is_patch )} " )
336+
334337 image_patches_flat = flatten_bn (image_patches )
338+ embed_is_patch = flatten_bn (embed_is_patch )
335339
336340 return FuyuImagePatchInputs (
337341 type = "image_patches" ,
@@ -351,20 +355,21 @@ def _process_image_input(
351355 assert self .vision_embed_tokens is not None
352356 vision_embeddings_flat , _ = self .vision_embed_tokens (
353357 image_patches_flat )
358+
354359 return vision_embeddings_flat .split (patches_per_image , dim = 0 )
355360
356361 def get_multimodal_embeddings (
357362 self , ** kwargs : object ) -> Optional [MultiModalEmbeddings ]:
358363 image_input = self ._parse_and_validate_image_input (** kwargs )
359364 if image_input is None :
360365 return None
361- vision_embeddings = self . _process_image_input ( image_input )
362- #return vision_embeddings
363- return flatten_2d_lists (
364- scatter_patch_features ( * args ) for args in zip (
365- vision_embeddings ,
366- image_input ["embed_is_patch" ],
367- ) )
366+
367+ image_features = self . _process_image_input ( image_input )
368+
369+ return scatter_patch_features (
370+ image_features ,
371+ image_input ["embed_is_patch" ],
372+ )
368373
369374 def get_input_embeddings (
370375 self ,
0 commit comments