
Commit 47a5c0b

Isotr0py authored and heheda12345 committed
[Bugfix] Fix Mllama interleaved images input support (vllm-project#15564)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
1 parent 8f0edd7 commit 47a5c0b

File tree

2 files changed: +73 -15 lines changed

examples/offline_inference/vision_language_multi_image.py

Lines changed: 2 additions & 2 deletions
@@ -190,8 +190,8 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
-    placeholders = "<|image|>" * len(image_urls)
-    prompt = f"{placeholders}<|begin_of_text|>{question}"
+    img_prompt = "Given the first image <|image|> and the second image<|image|>"
+    prompt = f"<|begin_of_text|>{img_prompt}, {question}?"
     return ModelRequestData(
         llm=llm,
         prompt=prompt,
vllm/model_executor/models/mllama.py

Lines changed: 71 additions & 13 deletions
@@ -171,6 +171,76 @@ def get_dummy_processor_inputs(
 class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
                                 ):
 
+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        return_mm_hashes: bool = False,
+    ) -> MultiModalEncDecInputs:
+        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
+                                  return_mm_hashes)
+
+        image_token_id = self.info.get_hf_config().image_token_index
+        # Check that the number of image tokens in the decoder prompt matches
+        # the number of images provided in mm_data
+        num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id)
+        image_data = mm_data.get("image", [])
+        num_images = 1 if isinstance(image_data, Image) else len(image_data)
+        if num_image_tokens != num_images:
+            raise ValueError(
+                f"The number of image tokens ({num_image_tokens}) must be"
+                f" the same as the number of images ({num_images})")
+
+        # Given prompt: <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6 ... (P-prefill, D-decode)  # noqa: E501
+        # P0 & P1 do cross attention with the placeholder of <IMG0>
+        # P3 P4 D5 D6 do cross attention with the placeholders of <IMG1> and <IMG2>  # noqa: E501
+        # Example input to encoder and decoder:
+        # {
+        #     'encoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128256, 128256, ..., 128256],
+        #         'prompt': '<|image|><|image|>...<|image|>',
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        #     'decoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
+        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        # }
+
+        if mm_data:
+            # Since only the last group of consecutive images is attended to
+            # by the decoded tokens, we only need the number of tokens for
+            # those images.
+            token_per_chunk = self.info.get_token_per_chunk_from_config()
+            num_decode_images = self._get_num_image_in_last_group(
+                mm_inputs["prompt_token_ids"])
+            num_encode_images = num_images - num_decode_images
+
+            # Set the encoder prompt length based on the number of tiles.
+            # This tells the block manager to allocate the correct number
+            # of slots for encoder tokens.
+            num_tiles = mm_inputs["mm_kwargs"]["num_tiles"]
+            decode_tiles = num_tiles[num_encode_images:num_images].sum().item()
+            num_tokens = decode_tiles * token_per_chunk
+            mm_inputs["encoder_prompt_token_ids"] = [image_token_id
+                                                     ] * num_tokens
+            mm_inputs["encoder_prompt"] = "<|image|>" * num_tokens
+
+        return mm_inputs
+
+    def _get_num_image_in_last_group(self, prompt_token_ids: List[int]) -> int:
+        num_images = 0
+        for token_id in prompt_token_ids[::-1]:
+            if token_id == self.info.get_hf_config().image_token_index:
+                num_images += 1
+            elif num_images > 0:
+                break
+        return num_images
+
     def _call_hf_processor(
         self,
         prompt: str,
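
The sizing logic in apply() can be read in isolation: scan the decoder token ids from the end to find the last consecutive run of image tokens, then size the encoder prompt by the tiles of just those images. A standalone sketch mirroring that logic with toy values; IMAGE_TOKEN_ID and TOKEN_PER_CHUNK are illustrative stand-ins for values the real code reads from the HF config:

# Standalone sketch of the last-group scan and encoder sizing; the constants
# are illustrative stand-ins, not values read from a real Mllama config.
from typing import List

IMAGE_TOKEN_ID = 128256  # stand-in for hf_config.image_token_index
TOKEN_PER_CHUNK = 1025   # stand-in for get_token_per_chunk_from_config()


def num_images_in_last_group(prompt_token_ids: List[int]) -> int:
    """Count images in the trailing run of consecutive image tokens,
    skipping any text tokens that come after it."""
    count = 0
    for token_id in reversed(prompt_token_ids):
        if token_id == IMAGE_TOKEN_ID:
            count += 1
        elif count > 0:
            break
    return count


# Decoder prompt: <IMG0> text <IMG1> <IMG2> text
# Only <IMG1> and <IMG2> (the last consecutive group) are cross-attended
# by newly decoded tokens, so only they need encoder slots.
prompt_token_ids = [IMAGE_TOKEN_ID, 100, 101,
                    IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 102]
num_decode_images = num_images_in_last_group(prompt_token_ids)
assert num_decode_images == 2

# Suppose the three images were split into 4, 2 and 3 tiles respectively.
num_tiles = [4, 2, 3]
num_images = 3
num_encode_images = num_images - num_decode_images       # 1 -> skip <IMG0>
decode_tiles = sum(num_tiles[num_encode_images:num_images])  # 2 + 3 = 5
num_tokens = decode_tiles * TOKEN_PER_CHUNK              # 5 * 1025 = 5125
encoder_prompt_token_ids = [IMAGE_TOKEN_ID] * num_tokens
print(len(encoder_prompt_token_ids))                     # 5125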
@@ -188,19 +258,7 @@ def _call_hf_processor(
         processed_outputs["num_tiles"] = torch.tensor(num_tiles)
         for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
             processed_outputs[k] = processed_outputs[k].squeeze(0)
-        # Example input to encoder and decoder:
-        # {
-        #     'encoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
-        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
-        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
-        #     },
-        #     'decoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128000],
-        #     },
-        # }
+
         processed_token_ids = processed_outputs.pop("input_ids")
         start_idx, end_idx = 0, processed_token_ids.size(1)
         processed_prompt_text = tokenizer.decode(processed_token_ids[0])
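
One user-visible effect of the new apply() is fail-fast validation: a decoder prompt whose <|image|> placeholder count does not match the number of supplied images now raises a ValueError during input processing. A hedged illustration; it assumes the Llama-3.2 vision checkpoint used by the example and enough GPU memory to load it, and the blank test images are placeholders:

# Hedged illustration of the mismatch check; loading the checkpoint requires
# its weights, and the blank PIL images stand in for real inputs.
from PIL import Image
from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.2-11B-Vision-Instruct",
          limit_mm_per_prompt={"image": 2})

# One <|image|> placeholder but two images -> ValueError:
# "The number of image tokens (1) must be the same as the number of images (2)"
llm.generate({
    "prompt": "<|begin_of_text|><|image|>Describe both images.",
    "multi_modal_data": {"image": [Image.new("RGB", (64, 64))] * 2},
})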
