Skip to content

Commit c3eca7b

Browse files
DarkLight1337 authored and Akshat-Tripathi committed
[VLM][Bugfix] Enable specifying prompt target via index (vllm-project#14038)
1 parent 367db4b commit c3eca7b

File tree

5 files changed: +432 -59 lines changed

tests/multimodal/test_processing.py

Lines changed: 256 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,8 @@
1414
# yapf conflicts with isort for this block
1515
# yapf: disable
1616
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
17-
PromptInsertion, PromptReplacement,
18-
apply_text_matches,
17+
PromptIndexTargets, PromptInsertion,
18+
PromptReplacement, apply_text_matches,
1919
apply_token_matches,
2020
find_mm_placeholders,
2121
find_text_matches, find_token_matches,
@@ -98,10 +98,20 @@ def test_iter_token_matches(token_ids, match_ids, expected):
9898
{
9999
"pattern_1": [],
100100
"pattern_2": [32000],
101+
"pattern_3": PromptIndexTargets.start(),
102+
"pattern_4": PromptIndexTargets.prefix([32000]),
103+
"pattern_5": PromptIndexTargets.end(),
101104
},
102105
{
103106
"pattern_1": [],
104107
"pattern_2": [],
108+
"pattern_3": [
109+
{ "start_idx": 0, "end_idx": 0 },
110+
],
111+
"pattern_4": [],
112+
"pattern_5": [
113+
{ "start_idx": 0, "end_idx": 0 },
114+
],
105115
},
106116
),
107117
(
@@ -110,6 +120,9 @@ def test_iter_token_matches(token_ids, match_ids, expected):
110120
"pattern_1": [32000],
111121
"pattern_2": [32000, 32000],
112122
"pattern_3": [32000, 32000, 32000],
123+
"pattern_4": PromptIndexTargets.start(),
124+
"pattern_5": PromptIndexTargets.prefix([32000]),
125+
"pattern_6": PromptIndexTargets.end(),
113126
},
114127
{
115128
"pattern_1": [
@@ -125,6 +138,15 @@ def test_iter_token_matches(token_ids, match_ids, expected):
125138
"pattern_3": [
126139
{ "start_idx": 0, "end_idx": 3 },
127140
],
141+
"pattern_4": [
142+
{ "start_idx": 0, "end_idx": 0 },
143+
],
144+
"pattern_5": [
145+
{ "start_idx": 1, "end_idx": 1 },
146+
],
147+
"pattern_6": [
148+
{ "start_idx": 4, "end_idx": 4 },
149+
],
128150
},
129151
),
130152
(
@@ -133,6 +155,9 @@ def test_iter_token_matches(token_ids, match_ids, expected):
133155
"pattern_1": [28747, 32000],
134156
"pattern_2": [28747, 32000, 32000, 32000],
135157
"pattern_3": [28747, 0, 32000],
158+
"pattern_4": PromptIndexTargets.start(),
159+
"pattern_5": PromptIndexTargets.prefix([28747, 32000]),
160+
"pattern_6": PromptIndexTargets.end(),
136161
},
137162
{
138163
"pattern_1": [
@@ -143,6 +168,13 @@ def test_iter_token_matches(token_ids, match_ids, expected):
143168
{ "start_idx": 1, "end_idx": 5 },
144169
],
145170
"pattern_3": [],
171+
"pattern_4": [
172+
{ "start_idx": 0, "end_idx": 0 },
173+
],
174+
"pattern_5": [],
175+
"pattern_6": [
176+
{ "start_idx": 10, "end_idx": 10 },
177+
],
146178
},
147179
),
148180
],
@@ -189,10 +221,20 @@ def test_find_token_matches(
189221
{
190222
"pattern_1": "",
191223
"pattern_2": "<image>",
224+
"pattern_3": PromptIndexTargets.start(),
225+
"pattern_4": PromptIndexTargets.prefix("<image>"),
226+
"pattern_5": PromptIndexTargets.end(),
192227
},
193228
{
194229
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
195230
"pattern_2": [],
231+
"pattern_3": [
232+
{ "start_idx": 0, "end_idx": 0 },
233+
],
234+
"pattern_4": [],
235+
"pattern_5": [
236+
{ "start_idx": 0, "end_idx": 0 },
237+
],
196238
}
197239
),
198240
(
@@ -201,6 +243,9 @@ def test_find_token_matches(
201243
"pattern_1": "<image>",
202244
"pattern_2": "<image><image>",
203245
"pattern_3": "<image><image><image>",
246+
"pattern_4": PromptIndexTargets.start(),
247+
"pattern_5": PromptIndexTargets.prefix("<image>"),
248+
"pattern_6": PromptIndexTargets.end(),
204249
},
205250
{
206251
"pattern_1": [
@@ -216,6 +261,15 @@ def test_find_token_matches(
216261
"pattern_3": [
217262
{ "start_idx": 0, "end_idx": 21 },
218263
],
264+
"pattern_4": [
265+
{ "start_idx": 0, "end_idx": 0 },
266+
],
267+
"pattern_5": [
268+
{ "start_idx": 7, "end_idx": 7 },
269+
],
270+
"pattern_6": [
271+
{ "start_idx": 28, "end_idx": 28 },
272+
],
219273
},
220274
),
221275
(
@@ -224,6 +278,9 @@ def test_find_token_matches(
224278
"pattern_1": "Image:<image>",
225279
"pattern_2": "Image:<image><image><image>",
226280
"pattern_3": "Image:<unk><image>",
281+
"pattern_4": PromptIndexTargets.start(),
282+
"pattern_5": PromptIndexTargets.prefix("Image:<image>"),
283+
"pattern_6": PromptIndexTargets.end(),
227284
},
228285
{
229286
"pattern_1": [
@@ -234,6 +291,15 @@ def test_find_token_matches(
234291
{ "start_idx": 0, "end_idx": 27 },
235292
],
236293
"pattern_3": [],
294+
"pattern_4": [
295+
{ "start_idx": 0, "end_idx": 0 },
296+
],
297+
"pattern_5": [
298+
{ "start_idx": 13, "end_idx": 13 },
299+
],
300+
"pattern_6": [
301+
{ "start_idx": 48, "end_idx": 48 },
302+
],
237303
},
238304
),
239305
# Test regex escape
@@ -325,6 +391,100 @@ def test_find_text_matches(
325391
},
326392
},
327393
),
394+
# Test index targets
395+
(
396+
"",
397+
{
398+
"pattern_1": PromptIndexTargets.start(),
399+
"pattern_2": PromptIndexTargets.prefix("<image>"),
400+
"pattern_3": PromptIndexTargets.end(),
401+
},
402+
{
403+
"pattern_1": "1",
404+
"pattern_2": "2",
405+
"pattern_3": "3",
406+
},
407+
{
408+
PromptInsertion: {
409+
0: "",
410+
1: "13",
411+
2: "1133",
412+
},
413+
PromptReplacement: {
414+
0: "",
415+
1: "13",
416+
2: "1133",
417+
},
418+
},
419+
),
420+
(
421+
"<image>",
422+
{
423+
"pattern_1": PromptIndexTargets.start(),
424+
"pattern_2": PromptIndexTargets.prefix("<image>"),
425+
"pattern_3": PromptIndexTargets.end(),
426+
},
427+
{
428+
"pattern_1": "1",
429+
"pattern_2": "2",
430+
"pattern_3": "3",
431+
},
432+
{
433+
PromptInsertion: {
434+
0: "<image>",
435+
1: "1<image>23",
436+
2: "11<image>2233",
437+
},
438+
PromptReplacement: {
439+
0: "<image>",
440+
1: "1<image>23",
441+
2: "11<image>2233",
442+
},
443+
},
444+
),
445+
# Test different replacement per item
446+
(
447+
"<image><image><image>",
448+
{
449+
"pattern_1": "<image>",
450+
},
451+
{
452+
"pattern_1": lambda idx: str(idx + 1),
453+
},
454+
{
455+
PromptInsertion: {
456+
0: "<image><image><image>",
457+
1: "<image>1<image><image>",
458+
2: "<image>12<image><image>",
459+
},
460+
PromptReplacement: {
461+
0: "<image><image><image>",
462+
1: "1<image><image>",
463+
2: "12<image>",
464+
},
465+
},
466+
),
467+
(
468+
"<image><image><image>",
469+
{
470+
"pattern_1": PromptIndexTargets.prefix("<image>"),
471+
},
472+
{
473+
"pattern_1": lambda idx: str(idx + 1),
474+
},
475+
{
476+
PromptInsertion: {
477+
0: "<image><image><image>",
478+
1: "<image>1<image><image>",
479+
2: "<image>12<image><image>",
480+
},
481+
PromptReplacement: {
482+
0: "<image><image><image>",
483+
1: "<image>1<image><image>",
484+
2: "<image>12<image><image>",
485+
},
486+
},
487+
),
328488
]
329489
)
330490
# yapf: enable
@@ -405,6 +565,100 @@ def test_find_update_text(
405565
},
406566
},
407567
),
568+
# Test index targets
569+
(
570+
[],
571+
{
572+
"pattern_1": PromptIndexTargets.start(),
573+
"pattern_2": PromptIndexTargets.prefix([32000]),
574+
"pattern_3": PromptIndexTargets.end(),
575+
},
576+
{
577+
"pattern_1": [-1],
578+
"pattern_2": [-2],
579+
"pattern_3": [-3],
580+
},
581+
{
582+
PromptInsertion: {
583+
0: [],
584+
1: [-1, -3],
585+
2: [-1, -1, -3, -3],
586+
},
587+
PromptReplacement: {
588+
0: [],
589+
1: [-1, -3],
590+
2: [-1, -1, -3, -3],
591+
},
592+
},
593+
),
594+
(
595+
[32000],
596+
{
597+
"pattern_1": PromptIndexTargets.start(),
598+
"pattern_2": PromptIndexTargets.prefix([32000]),
599+
"pattern_3": PromptIndexTargets.end(),
600+
},
601+
{
602+
"pattern_1": [-1],
603+
"pattern_2": [-2],
604+
"pattern_3": [-3],
605+
},
606+
{
607+
PromptInsertion: {
608+
0: [32000],
609+
1: [-1, 32000, -2, -3],
610+
2: [-1, -1, 32000, -2, -2, -3, -3],
611+
},
612+
PromptReplacement: {
613+
0: [32000],
614+
1: [-1, 32000, -2, -3],
615+
2: [-1, -1, 32000, -2, -2, -3, -3],
616+
},
617+
},
618+
),
619+
# Test different replacement per item
620+
(
621+
[32000, 32000, 32000],
622+
{
623+
"pattern_1": [32000],
624+
},
625+
{
626+
"pattern_1": lambda idx: [-(idx + 1)],
627+
},
628+
{
629+
PromptInsertion: {
630+
0: [32000, 32000, 32000],
631+
1: [32000, -1, 32000, 32000],
632+
2: [32000, -1, -2, 32000, 32000],
633+
},
634+
PromptReplacement: {
635+
0: [32000, 32000, 32000],
636+
1: [-1, 32000, 32000],
637+
2: [-1, -2, 32000],
638+
},
639+
},
640+
),
641+
(
642+
[32000, 32000, 32000],
643+
{
644+
"pattern_1": PromptIndexTargets.prefix([32000]),
645+
},
646+
{
647+
"pattern_1": lambda idx: [-(idx + 1)],
648+
},
649+
{
650+
PromptInsertion: {
651+
0: [32000, 32000, 32000],
652+
1: [32000, -1, 32000, 32000],
653+
2: [32000, -1, -2, 32000, 32000],
654+
},
655+
PromptReplacement: {
656+
0: [32000, 32000, 32000],
657+
1: [32000, -1, 32000, 32000],
658+
2: [32000, -1, -2, 32000, 32000],
659+
},
660+
},
661+
),
408662
]
409663
)
410664
# yapf: enable

vllm/model_executor/models/blip2.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,8 @@
1919
NestedTensors)
2020
from vllm.multimodal.parse import MultiModalDataItems
2121
from vllm.multimodal.processing import (BaseMultiModalProcessor,
22-
BaseProcessingInfo, PromptInsertion,
23-
PromptUpdate)
22+
BaseProcessingInfo, PromptIndexTargets,
23+
PromptInsertion, PromptUpdate)
2424
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
2525
from vllm.sequence import IntermediateTensors
2626

@@ -490,7 +490,7 @@ def _get_prompt_updates(
490490
return [
491491
PromptInsertion(
492492
modality="image",
493-
target="",
493+
target=PromptIndexTargets.start(),
494494
insertion=image_tokens,
495495
)
496496
]

vllm/model_executor/models/florence2.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,8 @@
2525
from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
2626
from vllm.multimodal.processing import (BaseProcessingInfo,
2727
EncDecMultiModalProcessor,
28-
PromptInsertion, PromptUpdate)
28+
PromptIndexTargets, PromptInsertion,
29+
PromptUpdate)
2930
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
3031
from vllm.sequence import IntermediateTensors
3132

@@ -864,7 +865,7 @@ def _get_prompt_updates(
864865
return [
865866
PromptInsertion(
866867
modality="image",
867-
target="",
868+
target=PromptIndexTargets.start(),
868869
insertion=image_tokens,
869870
)
870871
]

0 commit comments

Comments
 (0)