
Commit eda742d

Isotr0py authored and shreyankg committed
[VLM] Implement merged multimodal processor for Mllama (vllm-project#11427)
1 parent 80c186a commit eda742d

8 files changed (+456 −233 lines)

tests/models/encoder_decoder/vision_language/test_mllama.py

Lines changed: 67 additions & 4 deletions
@@ -7,11 +7,11 @@
 from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
                           BatchEncoding)
 
+from vllm import LLM, SamplingParams
 from vllm.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
                                      global_force_attn_backend_context_manager)
-from vllm.model_executor.models.mllama import (MLLAMA_IMAGE_TOKEN_ID,
-                                               MllamaForConditionalGeneration)
+from vllm.model_executor.models.mllama import MllamaForConditionalGeneration
 from vllm.multimodal.image import rescale_image_size
 from vllm.sequence import SampleLogprobs
 
@@ -21,6 +21,7 @@
 from ...utils import check_logprobs_close
 
 _LIMIT_IMAGE_PER_PROMPT = 3
+MLLAMA_IMAGE_TOKEN_ID = 128256
 
 LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
 
@@ -396,6 +397,64 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
     )
 
 
+@large_gpu_test(min_gb=48)
+@pytest.mark.core_model
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [32])
+def test_explicit_implicit_prompt(
+    image_assets: _ImageAssets,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+):
+    stop_sign = image_assets[0].pil_image
+    # yapf: disable
+    prompts = [
+        # explicit prompt
+        {
+            "encoder_prompt": {
+                "prompt": "<|image|>",
+                "multi_modal_data": {"image": stop_sign},
+            },
+            "decoder_prompt": {
+                "prompt_token_ids": [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374],  # noqa: E501
+            }
+        },
+        {
+            "encoder_prompt": "Not <|image|>",
+            "decoder_prompt": "The color of the sky is blue but sometimes it can also be",  # noqa: E501
+        },
+        # implicit prompt
+        {
+            "prompt": "<|begin_of_text|>The content of the image <|image|> is",  # noqa: E501
+            "multi_modal_data": {"image": stop_sign},
+        },
+        {
+            "prompt": "The color of the sky is blue but sometimes it can also be",  # noqa: E501
+        },
+    ]
+    # yapf: enable
+    llm = LLM(
+        model=model,
+        dtype=dtype,
+        max_model_len=4096,
+        max_num_seqs=2,
+        tensor_parallel_size=1,
+        enforce_eager=True,
+    )
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_tokens=max_tokens,
+    )
+    outputs = llm.generate(prompts, sampling_params)
+    n_prompts = len(prompts)
+    explicit_outputs = outputs[:n_prompts // 2]
+    implicit_outputs = outputs[n_prompts // 2:]
+    for exp_output, imp_output in zip(explicit_outputs, implicit_outputs):
+        assert exp_output.outputs[0].text == imp_output.outputs[0].text
+
+
 @large_gpu_test(min_gb=48)
 @pytest.mark.core_model
 @pytest.mark.parametrize("model", models)
@@ -458,6 +517,10 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
                     images=images)
 
 
+class DummyModel:
+    image_token_id = MLLAMA_IMAGE_TOKEN_ID
+
+
 @pytest.mark.core_model
 @pytest.mark.parametrize(
     "input_indices_and_output",
@@ -499,7 +562,7 @@ def test_get_cross_attention_mask(input_indices_and_output) -> None:
         use_cuda_graph=False,
     )
 
-    dummy: dict[str, str] = {}
+    dummy = DummyModel()
 
     cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\
         .get_cross_attention_mask(dummy,
@@ -556,7 +619,7 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
         use_cuda_graph=False,
     )
 
-    dummy: dict[str, str] = {}
+    dummy = DummyModel()
 
     full_text_row_masked_out_mask = MllamaForConditionalGeneration\
        .get_full_text_row_masked_out_mask(dummy,
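The explicit and implicit prompt forms exercised by `test_explicit_implicit_prompt` are expected to yield the same decoder token sequence, which is why the test asserts their generated texts match. As a rough sanity check outside the diff (assuming local access to the gated meta-llama/Llama-3.2-11B-Vision-Instruct tokenizer), the hard-coded `prompt_token_ids` should line up with the tokenized implicit prompt:

```python
# Sketch only: relate the explicit decoder prompt_token_ids in the test to the
# implicit prompt string. Requires access to the gated Mllama repo on the Hub.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct")

implicit_prompt = "<|begin_of_text|>The content of the image <|image|> is"
token_ids = tokenizer.encode(implicit_prompt, add_special_tokens=False)
print(token_ids)
# Expected to match the explicit decoder prompt used in the test:
# [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374]
```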

tests/models/multimodal/processing/test_common.py

Lines changed: 11 additions & 2 deletions
@@ -85,6 +85,14 @@ def _test_processing_correctness(
         partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
     }
 
+    tokenizer_encode_kwargs = {}
+    if model_config.hf_config.model_type == "mllama":
+        # For Mllama, tokenizer will always add bos_token at the beginning of
+        # prompt by default, causing hf_processor outputs incorrect token ids.
+        # So we need use `add_special_tokens=False` here to leave bos_token
+        # to be added by the processor.
+        tokenizer_encode_kwargs = {"add_special_tokens": False}
+
     for batch_idx in range(num_batches):
         mm_data = {
             k:
@@ -122,7 +130,7 @@ def _test_processing_correctness(
             f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
 
         baseline_tokenized_result = baseline_processor.apply(
-            tokenizer.encode(prompt),
+            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
             mm_data=mm_data,
             hf_processor_mm_kwargs={},
         )
@@ -131,7 +139,7 @@ def _test_processing_correctness(
             f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
 
         cached_tokenized_result = cached_processor.apply(
-            tokenizer.encode(prompt),
+            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
             mm_data=mm_data,
             hf_processor_mm_kwargs={},
         )
@@ -155,6 +163,7 @@ def _test_processing_correctness(
     "llava-hf/llava-v1.6-mistral-7b-hf",
     "llava-hf/LLaVA-NeXT-Video-7B-hf",
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
+    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "TIGER-Lab/Mantis-8B-siglip-llama3",
     "mistral-community/pixtral-12b",
     "openbmb/MiniCPM-o-2_6",

vllm/inputs/preprocess.py

Lines changed: 83 additions & 7 deletions
@@ -1,15 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import asyncio
-from typing import List, Mapping, Optional, Union
+from typing import List, Mapping, Optional, Tuple, Union, cast
 
 from typing_extensions import assert_never
 
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
+                                    MultiModalInputs)
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 
@@ -495,6 +496,51 @@ def _build_enc_dec_llm_inputs(
             decoder=decoder_inputs,
         )
 
+    def _separate_enc_dec_inputs_from_mm_processor_outputs(
+        self,
+        inputs: SingletonInputs,
+        decoder_inputs_to_override: Optional[SingletonInputs] = None,
+    ) -> Tuple[SingletonInputs, SingletonInputs]:
+        """
+        For encoder/decoder models only:
+        Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
+        """
+        encoder_inputs: SingletonInputs
+        decoder_inputs: SingletonInputs
+        if inputs["type"] == "multimodal":
+            # Multimodal data inputs
+            assert ("encoder_prompt" in inputs
+                    and "encoder_prompt_token_ids" in inputs)
+            inputs = cast(MultiModalEncDecInputs, inputs)
+            encoder_inputs = token_inputs(
+                prompt=inputs["encoder_prompt"],
+                prompt_token_ids=inputs["encoder_prompt_token_ids"],
+            )
+            if decoder_inputs_to_override is not None:
+                decoder_inputs = MultiModalInputs(
+                    type="multimodal",
+                    prompt=decoder_inputs_to_override.get("prompt", ""),
+                    prompt_token_ids=decoder_inputs_to_override[
+                        "prompt_token_ids"],
+                    mm_kwargs=inputs["mm_kwargs"],
+                    mm_placeholders=inputs["mm_placeholders"],
+                )
+            else:
+                decoder_inputs = MultiModalInputs(
+                    type="multimodal",
+                    prompt=inputs["prompt"],
+                    prompt_token_ids=inputs["prompt_token_ids"],
+                    mm_kwargs=inputs["mm_kwargs"],
+                    mm_placeholders=inputs["mm_placeholders"],
+                )
+        elif inputs["type"] == "token":
+            # Text-only inputs
+            encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
+            decoder_inputs = decoder_inputs_to_override or inputs
+        else:
+            assert_never(inputs)  # type: ignore[arg-type]
+        return encoder_inputs, decoder_inputs
+
     def _process_encoder_decoder_prompt(
         self,
         prompt: PromptType,
@@ -539,21 +585,35 @@ def _process_encoder_decoder_prompt(
                 prompt["encoder_prompt"],
                 request_id=request_id,
             )
-
             if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_inputs = None
             else:
                 decoder_inputs = self._prompt_to_llm_inputs(
                     decoder_input,
                     request_id=request_id,
                 )
+                # For multimodal model, override decoder prompt from processor
+                # with explicit decoder prompt.
+                if self.model_config.is_multimodal_model and (
+                        self._can_process_multimodal()):
+                    encoder_inputs, decoder_inputs = (
+                        self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                            encoder_inputs, decoder_inputs))
         else:
-            encoder_inputs = self._prompt_to_llm_inputs(
+            inputs = self._prompt_to_llm_inputs(
                 prompt,
                 request_id=request_id,
             )
+            if self.model_config.is_multimodal_model and (
+                    self._can_process_multimodal()):
+                # Encoder-Decoder Multimodal model
+                encoder_inputs, decoder_inputs = (
+                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                        inputs))
+            else:
+                encoder_inputs = inputs
 
-            decoder_inputs = None
+                decoder_inputs = None
 
         return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
 
@@ -583,13 +643,29 @@ async def _process_encoder_decoder_prompt_async(
 
                 encoder_inputs, decoder_inputs = await asyncio.gather(
                     encoder_task, decoder_task)
+
+                # For multimodal model, override decoder prompt from processor
+                # with explicit decoder prompt.
+                if self.model_config.is_multimodal_model and (
+                        self._can_process_multimodal()):
+                    encoder_inputs, decoder_inputs = (
+                        self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                            encoder_inputs, decoder_inputs))
         else:
-            encoder_inputs = await self._prompt_to_llm_inputs_async(
+            inputs = await self._prompt_to_llm_inputs_async(
                 prompt,
                 request_id=request_id,
             )
+            if self.model_config.is_multimodal_model and (
+                    self._can_process_multimodal()):
+                # Encoder-Decoder Multimodal model
+                encoder_inputs, decoder_inputs = (
+                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                        inputs))
+            else:
+                encoder_inputs = inputs
 
-            decoder_inputs = None
+                decoder_inputs = None
 
         return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
 
vllm/inputs/registry.py

Lines changed: 2 additions & 1 deletion
@@ -350,7 +350,8 @@ def dummy_data_for_profiling(
             )
             processor = mm_registry.create_processor(model_config, tokenizer)
             profiler = MultiModalProfiler(processor)
-            dummy_data = profiler.get_dummy_data(seq_len)
+            dummy_data = profiler.get_dummy_data(
+                seq_len, is_encoder_data=is_encoder_data)
         else:
             model_cls, _ = get_model_architecture(model_config)
             if is_encoder_data:
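With the merged processor path, dummy data for memory profiling is now requested separately for the encoder and the decoder via `is_encoder_data`. A hypothetical sketch of how a profiler might branch on that flag (`ToyProfiler` and its token counts are made up for illustration; this is not vLLM's `MultiModalProfiler` API):

```python
# Hypothetical sketch; names and numbers are illustrative only.
from dataclasses import dataclass
from typing import List


@dataclass
class DummyData:
    prompt_token_ids: List[int]


class ToyProfiler:

    def __init__(self, image_token_id: int, num_image_tokens: int):
        self.image_token_id = image_token_id
        self.num_image_tokens = num_image_tokens

    def get_dummy_data(self, seq_len: int,
                       is_encoder_data: bool = False) -> DummyData:
        if is_encoder_data:
            # Encoder profiling: fill with image placeholder tokens, capped
            # at the (assumed) per-request image token budget.
            n = min(seq_len, self.num_image_tokens)
            return DummyData([self.image_token_id] * n)
        # Decoder profiling: plain text tokens filling the full sequence.
        return DummyData([0] * seq_len)


profiler = ToyProfiler(image_token_id=128256, num_image_tokens=1601)
enc_dummy = profiler.get_dummy_data(4096, is_encoder_data=True)
dec_dummy = profiler.get_dummy_data(4096, is_encoder_data=False)
```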
