
Commit fc0f877

[Bugfix] Make dummy encoder prompt padding alternative and add missing warnings (#16129)
Signed-off-by: Isotr0py <2037008807@qq.com>
1 parent 0a57386 commit fc0f877
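
What this change does, in short: profiling for encoder-decoder multimodal models used to zero-pad the dummy encoder prompt up to the profiling sequence length unconditionally, which caused problems for mllama (see the regression test below for issue #13929). Padding is now opt-in through a new pad_dummy_encoder_prompt property on EncDecMultiModalProcessor (False by default, overridden to True for Whisper), and the profiler warns when the worst-case encoder prompt does not fit into the profiling sequence length. A minimal sketch of how a processor would opt in; MyEncDecProcessor is a made-up name and the other methods a real processor must implement are omitted:

from vllm.multimodal.processing import EncDecMultiModalProcessor

class MyEncDecProcessor(EncDecMultiModalProcessor):  # hypothetical subclass, abstract methods omitted
    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        # Ask the profiler to zero-pad the dummy encoder prompt up to
        # seq_len, as Whisper does; the base-class default is False.
        return True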

File tree

5 files changed: +108 additions, -4 deletions

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for mllama's multimodal preprocessing and profiling."""
+import pytest
+from transformers import MllamaConfig
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.profiling import MultiModalProfiler
+
+from ...utils import build_model_context
+
+
+@pytest.mark.parametrize("model_id",
+                         ["meta-llama/Llama-3.2-11B-Vision-Instruct"])
+@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072])
+@pytest.mark.parametrize("max_num_seqs", [1, 2, 8])
+def test_profiling(
+    model_id: str,
+    max_model_len: int,
+    max_num_seqs: int,
+):
+    # regression test for https://github.com/vllm-project/vllm/issues/13929
+    from vllm.model_executor.models.mllama import calc_token_per_chunk
+
+    model_config_kwargs = {
+        "max_model_len": max_model_len,
+    }
+    ctx = build_model_context(
+        model_id,
+        model_config_kwargs=model_config_kwargs,
+        limit_mm_per_prompt={"image": 1},
+    )
+
+    mm_config = ctx.get_mm_config()
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+    profiler = MultiModalProfiler(processor)
+
+    dummy_encoder_data = profiler.get_encoder_dummy_data(
+        max_model_len,
+        mm_counts=mm_config.limit_per_prompt,
+    )
+    dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
+        max_model_len,
+        mm_counts=mm_config.limit_per_prompt,
+    )
+
+    hf_config = ctx.get_hf_config(MllamaConfig)
+    image_size = hf_config.vision_config.image_size
+    encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids)
+                        ] * max_num_seqs
+
+    mm_kwargs = processor.apply(
+        prompt=dummy_mm_data.prompt_text,
+        mm_data=dummy_mm_data.mm_data,
+        hf_processor_mm_kwargs=dict(),
+    )["mm_kwargs"]
+
+    # Get the actual number of encoder tokens for each sample.
+    # Because attn_metadata.encoder_seq_lens only counts the last
+    # group of images for each sample, which is used to cheat the
+    # block manager to allocate blocks for those images only.
+    # See MllamaMultiModalProcessor for more details.
+    num_tiles = [[t] for t in mm_kwargs.pop("num_tiles")]
+    num_tokens_per_tile = calc_token_per_chunk(image_size)
+    actual_encoder_seq_lens = [
+        sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
+    ]
+
+    # simulate mllama image-present prefill.
+    for actual_len, last_group_len in zip(actual_encoder_seq_lens,
+                                          encoder_seq_lens):
+        assert actual_len >= last_group_len
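
The final assertion is the regression check for issue #13929: the dummy encoder prompt produced for profiling (compared per request as last_group_len) must not be longer than the encoder sequence mllama actually produces at prefill, i.e. all tiles times the tokens per tile. A toy version of the same comparison with made-up numbers (not mllama's real tile or token counts):

# Made-up numbers, purely to illustrate the check performed above.
num_tokens_per_tile = 100    # stand-in for calc_token_per_chunk(image_size)
num_tiles = [[4], [2]]       # two requests with 4 and 2 image tiles
actual_encoder_seq_lens = [sum(t) * num_tokens_per_tile for t in num_tiles]  # [400, 200]

encoder_seq_lens = [100, 100]  # stand-in for the dummy encoder prompt length per request

for actual_len, last_group_len in zip(actual_encoder_seq_lens, encoder_seq_lens):
    assert actual_len >= last_group_len  # the dummy prompt must not exceed the real length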

tests/models/utils.py

Lines changed: 3 additions & 0 deletions
@@ -255,6 +255,7 @@ def build_model_context(
     model_id: str,
     task: TaskOption = "auto",
     dtype: Union[str, torch.dtype] = "auto",
+    model_config_kwargs: Optional[dict[str, Any]] = None,
     mm_processor_kwargs: Optional[dict[str, Any]] = None,
     limit_mm_per_prompt: Optional[dict[str, int]] = None,
     disable_mm_preprocessor_cache: bool = True,
@@ -274,6 +275,7 @@ def build_model_context(
     model_info.check_available_online(on_fail="skip")
     model_info.check_transformers_version(on_fail="skip")
 
+    model_config_kwargs = model_config_kwargs or {}
     model_config = ModelConfig(
         model_id,
         task=task,
@@ -286,5 +288,6 @@ def build_model_context(
         limit_mm_per_prompt=limit_mm_per_prompt,
         disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
         hf_overrides=model_info.hf_overrides,
+        **model_config_kwargs,
     )
     return InputContext(model_config)
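
The new model_config_kwargs argument is forwarded directly into ModelConfig, so a test can override model-level settings such as max_model_len. For example, mirroring the new mllama test above (this assumes build_model_context is imported from tests/models/utils.py, as in that test):

ctx = build_model_context(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    model_config_kwargs={"max_model_len": 8192},
    limit_mm_per_prompt={"image": 1},
)
# ctx.model_config now reflects max_model_len=8192 for profiling.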

vllm/model_executor/models/whisper.py

Lines changed: 4 additions & 0 deletions
@@ -580,6 +580,10 @@ def _get_data_parser(self) -> MultiModalDataParser:
         feature_extractor = self.info.get_feature_extractor()
         return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
 
+    @property
+    def pad_dummy_encoder_prompt(self) -> bool:
+        return True
+
     def create_encoder_prompt(
         self,
         prompt: Union[str, list[int]],

vllm/multimodal/processing.py

Lines changed: 4 additions & 0 deletions
@@ -1654,6 +1654,10 @@ def create_encoder_prompt(
         """
         raise NotImplementedError
 
+    @property
+    def pad_dummy_encoder_prompt(self) -> bool:
+        return False
+
     def create_decoder_prompt(
         self,
         prompt: Union[str, list[int]],

vllm/multimodal/profiling.py

Lines changed: 26 additions & 4 deletions
@@ -15,7 +15,8 @@
 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                      MultiModalInputs, MultiModalKwargs,
                      MultiModalPlaceholderDict)
-from .processing import BaseMultiModalProcessor, BaseProcessingInfo
+from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
+                         EncDecMultiModalProcessor)
 
 logger = init_logger(__name__)
 
@@ -200,16 +201,37 @@ def get_encoder_dummy_data(
         seq_len: int,
         mm_counts: Optional[Mapping[str, int]] = None,
     ) -> DummyEncoderData:
-        mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len, mm_counts)
+        (
+            mm_inputs,
+            total_placeholders_by_modality,
+        ) = self.get_and_validate_mm_inputs(seq_len, mm_counts)
         mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
 
         # For encoder-decoder models, use encoder prompt token ids instead of
         # decoder prompt to construct dummy seq_data for encoder profiling.
         encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
 
         total_len = len(encoder_prompt_token_ids)
-        num_tokens_to_pad = max(total_len, seq_len) - total_len
-        encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
+
+        # Encoder-decoder multimodal models only support v0
+        if total_len > seq_len:
+            # `max_num_batched_tokens` is defined by `SchedulerConfig`
+            logger.warning(
+                "The encoder sequence length used for profiling ("
+                "max_num_batched_tokens / max_num_seqs = %d) is too short "
+                "to hold the multi-modal embeddings in the worst case "
+                "(%d tokens in total, out of which %s are reserved for "
+                "multi-modal embeddings). This may cause certain "
+                "multi-modal inputs to fail during inference, even when "
+                "the input text is short. To avoid this, you should "
+                "increase `max_model_len`, reduce `max_num_seqs`, "
+                "and/or reduce `mm_counts`.", seq_len, total_len,
+                total_placeholders_by_modality)
+
+        processor = cast(EncDecMultiModalProcessor, self.processor)
+        if processor.pad_dummy_encoder_prompt:
+            num_tokens_to_pad = max(total_len, seq_len) - total_len
+            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
 
         return DummyEncoderData(encoder_prompt_token_ids)
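
A toy walk-through of the new control flow with made-up numbers (a real run derives seq_len from the scheduler's max_num_batched_tokens / max_num_seqs and reads the padding flag from the model's processor):

# Hypothetical values, for illustration only.
max_num_batched_tokens = 8192
max_num_seqs = 8
seq_len = max_num_batched_tokens // max_num_seqs  # 1024 tokens per sequence

total_len = 6500  # pretend the worst-case dummy encoder prompt is this long

# New: warn when the profiling length cannot hold the worst-case encoder prompt.
if total_len > seq_len:
    print(f"profiling seq_len {seq_len} cannot hold {total_len} encoder tokens")

# New: zero-padding only happens for processors that opt in (e.g. Whisper).
pad_dummy_encoder_prompt = False  # base-class default
if pad_dummy_encoder_prompt:
    num_tokens_to_pad = max(total_len, seq_len) - total_len  # 0 in this example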
