1818""" PyTorch Fuyu model."""
1919import math
2020from collections .abc import Iterable , Mapping , Sequence
21- from typing import Literal , Optional , Set , Tuple , TypedDict
21+ from typing import Literal , Optional , Set , Tuple , TypedDict , Union
2222
2323import torch
2424import torch .nn as nn
3939 PromptUpdate , PromptUpdateDetails )
4040from vllm .multimodal .profiling import BaseDummyInputsBuilder , ProcessorInputs
4141from vllm .sequence import IntermediateTensors
42+ from vllm .utils import flatten_2d_lists
4243
4344from .interfaces import MultiModalEmbeddings , SupportsMultiModal , SupportsPP
4445from .utils import (AutoWeightsLoader , flatten_bn , maybe_prefix ,
4546 merge_multimodal_embeddings )
47+ from .vision import scatter_patch_features , select_patch_features
4648
4749# Cannot find the following 2 numbers from hf config.
4850_IMAGE_TOKEN_ID = 71011
@@ -64,6 +66,11 @@ class FuyuImagePatchInputs(TypedDict):
     This is used to split the embeddings which have the first two dimensions
     flattened just like `flat_data`.
     """
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
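+    Each image contributes one entry per image token, i.e.
+    `(ncols + 1) * nrows` values: True for patch tokens, False for the
+    newline separators.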
+    """
 
 
 class FuyuProcessingInfo(BaseProcessingInfo):
@@ -183,6 +190,19 @@ def _call_hf_processor(
 
         processed_outputs["image_patches"] = image_patches[0]
 
+        # get patch grid size for each image
+        embed_is_patch = []
+        for image in images:
+            ncols, nrows = self.info.get_image_feature_grid_size(
+                image_width=image.width,
+                image_height=image.height,
+            )
+
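+            # Fuyu lays out each grid row as `ncols` patch embeddings
+            # followed by a single newline (row-separator) embedding, so
+            # mark the patches True and the trailing separator False.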
+            mask = torch.tensor(([True] * ncols + [False]) * nrows)
+            embed_is_patch.append(mask)
+
+        processed_outputs["embed_is_patch"] = embed_is_patch
+
         return processed_outputs
 
     def _apply_hf_processor_tokens_only(
@@ -202,7 +222,8 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(image_patches=MultiModalFieldConfig.batched("image"))
+        return dict(image_patches=MultiModalFieldConfig.batched("image"),
+                    embed_is_patch=MultiModalFieldConfig.batched("image"))
 
     def _get_prompt_updates(
         self,
@@ -301,18 +322,23 @@ def _validate_shape(d: torch.Tensor):
     def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
         image_patches = kwargs.pop("image_patches", None)
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
         if image_patches is not None:
             if not isinstance(image_patches, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image patches. "
                                  f"Got type: {type(image_patches)}")
 
+            if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of embed_is_patch. "
+                                 f"Got type: {type(embed_is_patch)}")
             image_patches_flat = flatten_bn(image_patches)
 
             return FuyuImagePatchInputs(
                 type="image_patches",
                 flat_data=self._validate_pixel_values(
                     flatten_bn(image_patches_flat, concat=True)),
                 patches_per_image=[x.size(0) for x in image_patches_flat],
+                embed_is_patch=embed_is_patch,
             )
 
         return None
@@ -333,7 +359,12 @@ def get_multimodal_embeddings(
         if image_input is None:
             return None
         vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
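+        # Scatter each image's patch embeddings into a sequence that lines
+        # up one-to-one with the image tokens (patches and newlines) marked
+        # in embed_is_patch; non-patch positions hold placeholder values.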
+        return flatten_2d_lists(
+            scatter_patch_features(*args) for args in zip(
+                vision_embeddings,
+                image_input["embed_is_patch"],
+            ))
 
     def get_input_embeddings(
         self,
@@ -343,8 +374,8 @@ def get_input_embeddings(
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
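+            # select_patch_features drops the non-patch (newline) slots
+            # again, so only real patch embeddings are merged at the
+            # _IMAGE_TOKEN_ID positions.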
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                _IMAGE_TOKEN_ID)
+                input_ids, inputs_embeds,
+                select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
         return inputs_embeds
 
     def forward(