@@ -180,19 +180,66 @@ def apply(
         mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                   return_mm_hashes)
 
+        image_token_id = self.info.get_hf_config().image_token_index
         # Check that the number of image tokens in the decoder prompt matches
         # the number of images provided in mm_data
-        num_image_tokens = mm_inputs['prompt_token_ids'].count(
-            self.info.get_hf_config().image_token_index)
+        num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id)
         image_data = mm_data.get("image", [])
         num_images = 1 if isinstance(image_data, Image) else len(image_data)
         if num_image_tokens != num_images:
             raise ValueError(
                 f"The number of image tokens ({num_image_tokens}) must be"
                 f" the same as the number of images ({num_images})")
 
+        # Given the prompt <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6 ... (P = prefill, D = decode),  # noqa: E501
+        # P0 and P1 cross-attend to the placeholder of <IMG0>, while
+        # P3, P4, D5, D6, ... cross-attend to the placeholders of <IMG1> and <IMG2>.  # noqa: E501
+        # Example input to the encoder and decoder:
+        # {
+        #     'encoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128256, 128256, ..., 128256],
+        #         'prompt': '<|image|><|image|>...<|image|>',
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        #     'decoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
+        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        # }
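+        # (In the example above, 128256 is the <|image|> placeholder id,
+        # i.e. image_token_index, and 128000 is the Llama tokenizer's
+        # <|begin_of_text|> id.)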
+
+        if mm_data:
+            # Since only the last group of consecutive image tokens is
+            # attended to by the decoded tokens, we only need the token
+            # count for those images.
+            token_per_chunk = self.info.get_token_per_chunk_from_config()
+            num_decode_images = self._get_num_image_in_last_group(
+                mm_inputs["prompt_token_ids"])
+            num_encode_images = num_images - num_decode_images
+
+            # Set the encoder prompt length based on the number of tiles.
+            # This tells the block manager to allocate the correct number
+            # of slots for the encoder tokens.
+            num_tiles = mm_inputs["mm_kwargs"]["num_tiles"]
+            decode_tiles = num_tiles[num_encode_images:num_images].sum().item()
+            num_tokens = decode_tiles * token_per_chunk
+            mm_inputs["encoder_prompt_token_ids"] = [image_token_id
+                                                     ] * num_tokens
+            mm_inputs["encoder_prompt"] = "<|image|>" * num_tokens
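+
+            # Worked example (illustrative numbers, not taken from this
+            # change): for the prompt <IMG0> P0 <IMG1> <IMG2> D5 ...,
+            # num_images = 3 and the trailing image group gives
+            # num_decode_images = 2, so num_encode_images = 1. With
+            # num_tiles = [2, 4, 4], decode_tiles = 4 + 4 = 8; assuming
+            # a hypothetical token_per_chunk of 1601, the encoder prompt
+            # is padded to 8 * 1601 = 12808 image tokens.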
+
         return mm_inputs
 
+    def _get_num_image_in_last_group(self, prompt_token_ids: List[int]) -> int:
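+        """Count the images in the last group of consecutive image
+        tokens in the decoder prompt; only this trailing group is
+        cross-attended by the decode-phase tokens.
+
+        Illustrative example (token ids are assumptions): with
+        image_token_index == 128256, scanning
+        [128000, 128256, 128256, 3923] in reverse skips the trailing
+        text token, counts the two adjacent image tokens, stops at the
+        first non-image token before them, and returns 2.
+        """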
+        num_images = 0
+        for token_id in prompt_token_ids[::-1]:
+            if token_id == self.info.get_hf_config().image_token_index:
+                num_images += 1
+            elif num_images > 0:
+                break
+        return num_images
+
     def _call_hf_processor(
         self,
         prompt: str,
@@ -210,19 +257,7 @@ def _call_hf_processor(
         processed_outputs["num_tiles"] = torch.tensor(num_tiles)
         for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
             processed_outputs[k] = processed_outputs[k].squeeze(0)
-        # Example input to encoder and decoder:
-        # {
-        #     'encoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
-        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
-        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
-        #     },
-        #     'decoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128000],
-        #     },
-        # }
+
         processed_token_ids = processed_outputs.pop("input_ids")
         start_idx, end_idx = 0, processed_token_ids.size(1)
         processed_prompt_text = tokenizer.decode(processed_token_ids[0])