1717
1818from typing import Iterable , Union
1919
20+ import numpy as np
21+
2022from ...feature_extraction_utils import BatchFeature
2123from ...image_utils import ImageInput , get_image_size , to_numpy_array
22- from ...processing_utils import ProcessingKwargs , ProcessorMixin , Unpack
24+ from ...processing_utils import MultiModalData , ProcessingKwargs , ProcessorMixin , Unpack
2325from ...tokenization_utils_base import PreTokenizedInput , TextInput
2426from ...utils import logging
2527from ...video_utils import VideoInput
@@ -32,6 +34,7 @@ class PerceptionLMProcessorKwargs(ProcessingKwargs, total=False):
3234 _defaults = {
3335 "text_kwargs" : {
3436 "padding" : False ,
37+ "return_mm_token_type_ids" : False ,
3538 },
3639 }
3740
@@ -157,9 +160,17 @@ def __call__(
157160 prompt_strings .append (sample )
158161
159162 return_tensors = output_kwargs ["text_kwargs" ].pop ("return_tensors" , None )
160- text_inputs = self .tokenizer (prompt_strings , ** output_kwargs ["text_kwargs" ])
163+ return_mm_token_type_ids = output_kwargs ["text_kwargs" ].pop ("return_mm_token_type_ids" , False )
164+ text_inputs = self .tokenizer (prompt_strings , ** output_kwargs ["text_kwargs" ], return_tensors = None )
161165 self ._check_special_mm_tokens (prompt_strings , text_inputs , modalities = ["image" , "video" ])
162- return BatchFeature (data = {** text_inputs , ** image_inputs , ** videos_inputs }, tensor_type = return_tensors )
166+
167+ if return_mm_token_type_ids :
168+ array_ids = np .array (text_inputs ["input_ids" ])
169+ mm_token_type_ids = np .zeros_like (text_inputs ["input_ids" ])
170+ mm_token_type_ids [array_ids == self .image_token_id ] = 1
171+ text_inputs ["mm_token_type_ids" ] = mm_token_type_ids .tolist ()
172+
173+ return BatchFeature (data = {** text_inputs , ** image_inputs }, tensor_type = return_tensors )
163174
164175 def _expand_media_tokens (self , sample , media_token : str , media_iter : Iterable ):
165176 media_count = sample .count (media_token )
@@ -183,6 +194,50 @@ def _expand_media_tokens(self, sample, media_token: str, media_iter: Iterable):
183194 sample += sample_splits [- 1 ]
184195 return sample
185196
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
    """
    Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

    Args:
        image_sizes (`list[list[int]]`, *optional*):
            The input sizes formatted as (height, width) per each image.

    Returns:
        `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
        input modalities, along with other useful data.
    """
    vision_data = {}
    if image_sizes is not None:
        # Merge call-site kwargs over the class-level defaults WITHOUT mutating the shared
        # `_defaults` class attribute. The previous `.get(...)` + `.update(kwargs)` pattern
        # would have updated `_defaults["images_kwargs"]` in place if that key existed.
        images_kwargs = {**PerceptionLMProcessorKwargs._defaults.get("images_kwargs", {}), **kwargs}
        tile_size = images_kwargs.get("tile_size", None) or self.image_processor.tile_size

        # Each tile contributes a (tile_size / patch_size / pooling_ratio)^2 grid of tokens;
        # invariant across images, so compute it once outside the loop.
        tokens_per_tile = (tile_size // self.patch_size // self.pooling_ratio) * (
            tile_size // self.patch_size // self.pooling_ratio
        )

        num_image_tokens = []
        num_image_patches = []
        for height, width in image_sizes:
            if self.image_processor.vision_input_type == "thumb+tile":
                # Prefer an aspect ratio that fits the image on the canvas; fall back to
                # the closest supported aspect ratio when no exact fit exists.
                aspect_ratio = self.image_processor._fit_image_to_canvas(
                    img_width=width, img_height=height, tile_size=tile_size
                )
                if aspect_ratio is None:
                    aspect_ratio = self.image_processor._find_closest_aspect_ratio(
                        img_width=width, img_height=height, tile_size=tile_size
                    )
                num_tiles = aspect_ratio[0] * aspect_ratio[1] + 1  # base (thumbnail) image and tiles
            else:
                num_tiles = 1

            num_image_tokens.append(tokens_per_tile * num_tiles)
            num_image_patches.append(num_tiles)

        vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
    return MultiModalData(**vision_data)
240+
186241 def batch_decode (self , * args , ** kwargs ):
187242 """
188243 This method forwards all its arguments to PerceptionLMTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
0 commit comments