Commit 93abf23
[VLM] Fully dynamic prompt replacement in merged input processor (#11199)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 9c3dadd commit 93abf23
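
At a glance, the API change these diffs exercise (reconstructed from the test updates below; names as they appear in vllm.multimodal.processing at this commit):

# Before: a replacement was a repeated unit, with the modality key
# supplied at bind time:
#     PromptReplacement(target, repl_unit, repl_count).bind(key, tokenizer)
#
# After: the key moves into the constructor and the replacement is an
# arbitrary sequence, i.e. fully dynamic rather than `unit * count`:
#     PromptReplacement(key, target, replacement).bind(tokenizer)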

12 files changed: +569 −510 lines changed

examples/offline_inference_vision_language.py

Lines changed: 1 addition & 4 deletions

@@ -97,9 +97,6 @@ def run_phi3v(question: str, modality: str):
     # max_model_len (128k) for this model may cause OOM.
     # You may lower either to run this example on lower-end GPUs.
 
-    # In this example, we override max_num_seqs to 5 while
-    # keeping the original context length of 128k.
-
     # num_crops is an override kwarg to the multimodal image processor;
     # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
     # to use 16 for single frame scenarios, and 4 for multi-frame.
@@ -113,7 +110,7 @@ def run_phi3v(question: str, modality: str):
     # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
     # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
     llm = LLM(
-        model="microsoft/Phi-3-vision-128k-instruct",
+        model="microsoft/Phi-3.5-vision-instruct",
         trust_remote_code=True,
         max_model_len=4096,
         max_num_seqs=2,
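
For reference, a minimal sketch of what the updated example boils down to. The mm_processor_kwargs route for the num_crops override, the chat template, and the image path are illustrative assumptions here, not part of this diff:

from PIL import Image

from vllm import LLM

# Mirrors the updated run_phi3v(): Phi-3.5-vision-instruct with a reduced
# context length (4096) and batch size (2) to fit lower-end GPUs.
llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    max_model_len=4096,
    max_num_seqs=2,
    # Assumed override path for num_crops (16 suits single-frame use,
    # per the comments above); verify against your vLLM version.
    mm_processor_kwargs={"num_crops": 16},
)

image = Image.open("example.jpg")  # hypothetical input image
outputs = llm.generate({
    "prompt": "<|user|>\n<|image_1|>\nWhat is in this image?<|end|>\n<|assistant|>\n",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)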

tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py

Lines changed: 2 additions & 2 deletions

@@ -16,8 +16,8 @@
 # Wrap lazy imports to avoid initializing CUDA during test collection
 @pytest.fixture()
 def processor_for_phi3v():
-    from vllm.model_executor.models.phi3v import Phi3VProcessor
-    return Phi3VProcessor
+    from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
+    return Phi3VMultiModalProcessor
 
 
 @pytest.fixture()
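
For context, the fixture's lazy-import pattern after the rename, with the rationale spelled out in comments (a sketch, not the verbatim test file):

import pytest


@pytest.fixture()
def processor_for_phi3v():
    # A module-level import of vllm.model_executor.models.phi3v could
    # initialize CUDA while pytest is still collecting tests, so the
    # import is deferred until the fixture actually runs.
    from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
    return Phi3VMultiModalProcessor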

tests/multimodal/test_processing.py

Lines changed: 54 additions & 51 deletions

@@ -1,11 +1,11 @@
 from typing import cast
 
 import pytest
-from transformers import BatchFeature
 
-from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
-                                        find_text_matches, find_token_matches,
-                                        iter_placeholders, iter_token_matches,
+from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
+                                        _PlaceholderInfo, find_text_matches,
+                                        find_token_matches, iter_placeholders,
+                                        iter_token_matches,
                                         replace_text_matches,
                                         replace_token_matches)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -16,7 +16,7 @@
 @pytest.mark.parametrize(
     ("token_ids", "match_ids", "expected"),
     [
-        ([], [], [{ "start_idx": 0, "end_idx": 0 }]),
+        ([], [], []),
         ([], [32000], []),
         (
             [32000, 32000, 32000],
@@ -83,7 +83,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
             "pattern_2": [32000],
         },
         {
-            "pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
+            "pattern_1": [],
             "pattern_2": [],
         }
     ),
@@ -136,7 +136,7 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
     mock_tokenizer = cast(AnyTokenizer, object())
 
     prompt_repls = [
-        PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
+        PromptReplacement(key, target, []).bind(mock_tokenizer)
         for key, target in target_by_key.items()
     ]
     result = find_token_matches(prompt, prompt_repls)
@@ -243,7 +243,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
     mock_tokenizer = cast(AnyTokenizer, object())
 
     prompt_repls = [
-        PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
+        PromptReplacement(key, target, []).bind(mock_tokenizer)
         for key, target in target_by_key.items()
     ]
     result = find_text_matches(prompt, prompt_repls)
@@ -276,12 +276,12 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
             "pattern_3": "!",
         },
         {
-            # Test whether target is confused with repl_unit
-            "pattern_1": ("<image><image>", 1),
-            # Test empty repl_unit
-            "pattern_2": ("", 1),
-            # Test multiple repl_count
-            "pattern_3": ("?", 2),
+            # Test whether target is confused with replacement
+            "pattern_1": "<image><image>",
+            # Test empty replacement
+            "pattern_2": "",
+            # Test dynamic replacement (beyond the form of `unit * count`)
+            "pattern_3": "?!?",
         },
     ),
 ]
@@ -290,8 +290,8 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
     ("mm_count", "expected"),
     [
         (0, "Image:<image>Image:<image><image>!"),
-        (1, "<image><image>Image:<image><image>??"),
-        (2, "<image><image><image><image><image>??"),
+        (1, "<image><image>Image:<image><image>?!?"),
+        (2, "<image><image><image><image><image>?!?"),
     ]
 )
 # yapf: enable
@@ -306,17 +306,16 @@ def test_find_replace_text(
     mock_tokenizer = cast(AnyTokenizer, object())
 
     prompt_repls = [
-        PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
+        PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
         for key, target in target_by_key.items()
     ]
     matches = find_text_matches(prompt, prompt_repls)
 
     result = replace_text_matches(
         prompt,
         matches,
-        {key: list(range(mm_count))
-         for key in repl_by_key},
-        BatchFeature(),
+        MultiModalDataItems({key: [None] * mm_count
+                             for key in repl_by_key}),
     )
 
     # Only displayed on error
@@ -343,12 +342,12 @@ def test_find_replace_text(
             "pattern_3": [918],
         },
         {
-            # Test whether target is confused with repl_unit
-            "pattern_1": ([32000, 32000], 1),
-            # Test empty repl_unit
-            "pattern_2": ([], 1),
-            # Test multiple repl_count
-            "pattern_3": ([1550], 2),
+            # Test whether target is confused with replacement
+            "pattern_1": [32000, 32000],
+            # Test empty replacement
+            "pattern_2": [],
+            # Test dynamic replacement (beyond the form of `unit * count`)
+            "pattern_3": [1550, 918, 1550],
         },
     ),
 ]
@@ -357,8 +356,8 @@ def test_find_replace_text(
     ("mm_count", "expected"),
     [
         (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
-        (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]),
-        (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]),
+        (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
+        (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
     ]
 )
 # yapf: enable
@@ -373,17 +372,16 @@ def test_find_replace_tokens(
     mock_tokenizer = cast(AnyTokenizer, object())
 
     prompt_repls = [
-        PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
+        PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
         for key, target in target_by_key.items()
     ]
     matches = find_token_matches(prompt, prompt_repls)
 
     result = replace_token_matches(
         prompt,
         matches,
-        {key: list(range(mm_count))
-         for key in repl_by_key},
-        BatchFeature(),
+        MultiModalDataItems({key: [None] * mm_count
+                             for key in repl_by_key}),
     )
 
     # Only displayed on error
@@ -399,9 +397,9 @@ def test_find_replace_tokens(
     "repl_by_key",
     [
         {
-            "pattern_1": ([32000, 32000], 1),
-            "pattern_2": ([], 1),
-            "pattern_3": ([1550], 2),
+            "pattern_1": [32000, 32000],
+            "pattern_2": [],
+            "pattern_3": [1550, 918, 1550],
         },
     ],
 )
@@ -414,48 +412,47 @@ def test_find_replace_tokens(
             _PlaceholderInfo(
                 modality="pattern_1",
                 start_idx=6,
-                unit=[32000, 32000],
-                unit_count=1,
+                replacement=[32000, 32000],
             ),
         ],
     ),
     (
-        [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550],
+        [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
         [
             _PlaceholderInfo(
                 modality="pattern_1",
                 start_idx=1,
-                unit=[32000, 32000],
-                unit_count=1,
+                replacement=[32000, 32000],
             ),
             _PlaceholderInfo(
                 modality="pattern_1",
                 start_idx=5,
-                unit=[32000, 32000],
-                unit_count=1,
+                replacement=[32000, 32000],
             ),
             _PlaceholderInfo(
                 modality="pattern_3",
                 start_idx=7,
-                unit=[1550],
-                unit_count=2,
+                replacement=[1550, 918, 1550],
             ),
         ],
    ),
    (
-        [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550],
+        [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
        [
            _PlaceholderInfo(
                modality="pattern_1",
                start_idx=1,
-                unit=[32000, 32000],
-                unit_count=2,
+                replacement=[32000, 32000],
+            ),
+            _PlaceholderInfo(
+                modality="pattern_1",
+                start_idx=3,
+                replacement=[32000, 32000],
            ),
            _PlaceholderInfo(
                modality="pattern_3",
                start_idx=6,
-                unit=[1550],
-                unit_count=2,
+                replacement=[1550, 918, 1550],
            ),
        ],
    ),
@@ -470,11 +467,17 @@ def test_iter_placeholders(
     mock_tokenizer = cast(AnyTokenizer, object())
 
     prompt_repls = [
-        PromptReplacement([], *repl).bind(key, mock_tokenizer)
+        PromptReplacement(key, [], repl).bind(mock_tokenizer)
         for key, repl in repl_by_key.items()
     ]
 
-    result = list(iter_placeholders(prompt_repls, prompt))
+    result = list(
+        iter_placeholders(
+            prompt_repls,
+            prompt,
+            # Effectively match all occurrences in the prompt
+            MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
+        ))
 
     # Only displayed on error
     print("result:", result)
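
Putting the new API together, a minimal sketch of the token-level flow these tests exercise (token IDs borrowed from the test data; semantics as of this commit):

from typing import cast

from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
                                        find_token_matches,
                                        replace_token_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer

# Token-level targets and replacements never call into the tokenizer,
# so a bare object() can stand in for it, just as the tests above do.
mock_tokenizer = cast(AnyTokenizer, object())

# New signature: modality key, then target, then an arbitrary replacement
# sequence; bind() now takes only the tokenizer.
prompt_repls = [
    PromptReplacement("image", [32000], [32000, 32000, 32000]).bind(mock_tokenizer),
]

prompt = [1, 9833, 28747, 32000, 918]
matches = find_token_matches(prompt, prompt_repls)

# The number of entries per modality controls how many matches get
# replaced; one dummy item here means exactly one replacement.
new_prompt = replace_token_matches(
    prompt,
    matches,
    MultiModalDataItems({"image": [None]}),
)
print(new_prompt)  # expected: [1, 9833, 28747, 32000, 32000, 32000, 918]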

tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py

Lines changed: 2 additions & 2 deletions

@@ -3,14 +3,14 @@
 import torch
 
 from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
-                                              LlavaProcessor,
+                                              LlavaMultiModalProcessor,
                                               get_max_llava_image_tokens)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
 
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
-@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
+@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
 class MyLlava(LlavaForConditionalGeneration):
 
     def compute_logits(
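
Out-of-tree models keep the same registration pattern apart from the rename; a minimal sketch of the decorator stack (the comments describe the hooks as this diff uses them):

from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
                                              LlavaMultiModalProcessor,
                                              get_max_llava_image_tokens)
from vllm.multimodal import MULTIMODAL_REGISTRY


# register_max_image_tokens supplies the per-image token upper bound used
# when profiling memory for multimodal inputs; register_processor attaches
# the merged multimodal input processor (renamed from LlavaProcessor here).
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class MyLlava(LlavaForConditionalGeneration):
    pass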
