@@ -171,6 +171,76 @@ def get_dummy_processor_inputs(
 class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
                                 ):
 
+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        return_mm_hashes: bool = False,
+    ) -> MultiModalEncDecInputs:
+        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
+                                  return_mm_hashes)
+
+        image_token_id = self.info.get_hf_config().image_token_index
+        # Check that the number of image tokens in the decoder prompt
+        # matches the number of images provided in mm_data.
+        num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id)
+        image_data = mm_data.get("image", [])
+        num_images = 1 if isinstance(image_data, Image) else len(image_data)
+        if num_image_tokens != num_images:
+            raise ValueError(
+                f"The number of image tokens ({num_image_tokens}) must be"
+                f" the same as the number of images ({num_images})")
+
+        # Given prompt: <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6 ... (P: prefill, D: decode)  # noqa: E501
+        # P0 and P1 cross-attend to the placeholder of <IMG0>;
+        # P3, P4, D5, D6 cross-attend to the placeholders of <IMG1> and <IMG2>.
+        # Example input to encoder and decoder:
+        # {
+        #     'encoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128256, 128256, ..., 128256],
+        #         'prompt': '<|image|><|image|>...<|image|>',
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        #     'decoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
+        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        # }
+
+        if mm_data:
+            # Since only the last group of consecutive images is attended
+            # to by the decoded tokens, we only need the token count for
+            # those images.
+            token_per_chunk = self.info.get_token_per_chunk_from_config()
+            num_decode_images = self._get_num_image_in_last_group(
+                mm_inputs["prompt_token_ids"])
+            num_encode_images = num_images - num_decode_images
+
+            # Set the encoder prompt length based on the number of tiles.
+            # This tells the block manager to allocate the correct number
+            # of slots for encoder tokens.
+            num_tiles = mm_inputs["mm_kwargs"]["num_tiles"]
+            decode_tiles = num_tiles[num_encode_images:num_images].sum().item()
+            num_tokens = decode_tiles * token_per_chunk
+            mm_inputs["encoder_prompt_token_ids"] = [image_token_id
+                                                     ] * num_tokens
+            mm_inputs["encoder_prompt"] = "<|image|>" * num_tokens
+
+        return mm_inputs
+
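To make the tile arithmetic above concrete, here is a minimal worked sketch. The values are illustrative assumptions: `token_per_chunk` is taken as 1601 (in the real code it comes from `get_token_per_chunk_from_config()`), and the `num_tiles` tensor is invented for three images, of which the last two form the trailing group.

    import torch

    token_per_chunk = 1601               # assumed tokens per tile
    num_tiles = torch.tensor([4, 4, 2])  # tiles per image (assumed)
    num_images = 3
    num_decode_images = 2                # images in the trailing group
    num_encode_images = num_images - num_decode_images  # 1

    # Only the trailing group of images is cross-attended by decode
    # tokens, so the encoder prompt is sized from their tiles alone.
    decode_tiles = num_tiles[num_encode_images:num_images].sum().item()  # 4 + 2 = 6
    num_tokens = decode_tiles * token_per_chunk                          # 6 * 1601 = 9606

The encoder prompt then becomes `num_tokens` repetitions of the image token, which is what tells the block manager how many encoder slots to reserve.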
+    def _get_num_image_in_last_group(self,
+                                     prompt_token_ids: list[int]) -> int:
+        # Count the image tokens in the last group of consecutive
+        # image tokens, scanning the prompt from the end.
+        num_images = 0
+        for token_id in prompt_token_ids[::-1]:
+            if token_id == self.info.get_hf_config().image_token_index:
+                num_images += 1
+            elif num_images > 0:
+                break
+        return num_images
+
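As a standalone illustration of the reverse scan in `_get_num_image_in_last_group`, the sketch below re-implements it as a free function, using 128256 as the image token id (the value shown in the examples above):

    IMAGE_TOKEN_ID = 128256  # Mllama image token id from the example above

    def num_images_in_last_group(prompt_token_ids: list[int]) -> int:
        num_images = 0
        for token_id in reversed(prompt_token_ids):
            if token_id == IMAGE_TOKEN_ID:
                num_images += 1
            elif num_images > 0:
                # First non-image token before the trailing group: stop.
                break
        return num_images

    # <IMG> text <IMG> <IMG> text text -> the trailing group has 2 images.
    assert num_images_in_last_group([128256, 100, 128256, 128256, 101, 102]) == 2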
     def _call_hf_processor(
         self,
         prompt: str,
@@ -188,19 +258,7 @@ def _call_hf_processor(
         processed_outputs["num_tiles"] = torch.tensor(num_tiles)
         for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
             processed_outputs[k] = processed_outputs[k].squeeze(0)
-        # Example input to encoder and decoder:
-        # {
-        #     'encoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
-        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
-        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
-        #     },
-        #     'decoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128000],
-        #     },
-        # }
+
         processed_token_ids = processed_outputs.pop("input_ids")
         start_idx, end_idx = 0, processed_token_ids.size(1)
         processed_prompt_text = tokenizer.decode(processed_token_ids[0])