Commit 06af9e0

Fix prompt replacement

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

1 parent c145f49 commit 06af9e0

3 files changed: +167 −27 lines changed

tests/multimodal/test_processing.py

Lines changed: 54 additions & 1 deletion
@@ -19,7 +19,8 @@
                                         apply_token_matches,
                                         find_mm_placeholders,
                                         find_text_matches, find_token_matches,
-                                        iter_token_matches)
+                                        iter_token_matches,
+                                        replace_token_matches)
 # yapf: enable
 from vllm.multimodal.profiling import MultiModalProfiler
 from vllm.transformers_utils.tokenizer import (AnyTokenizer,
@@ -89,6 +90,58 @@ def test_iter_token_matches(token_ids, match_ids, expected):
     assert all(match_len == len(match_ids) for match_len in match_lens)
 
 
+# yapf: disable
+@pytest.mark.parametrize(
+    ("token_ids", "match_ids", "new_ids", "expected"),
+    [
+        ([], [], [-1], []),
+        ([], [32000], [-1], []),
+        (
+            [32000, 32000, 32000],
+            [32000],
+            [-1],
+            [-1, -1, -1],
+        ),
+        (
+            [32000, 32000, 32000],
+            [32000, 32000],
+            [-1],
+            [-1, 32000],
+        ),
+        (
+            [32000, 32000, 32000],
+            [32000, 32000, 32000],
+            [-1],
+            [-1],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 32000],
+            [-1],
+            [9833, -1, 32000, 32000, 9833, -1, 32000, 918],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 32000, 32000, 32000],
+            [-1],
+            [9833, -1, 9833, 28747, 32000, 32000, 918],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 0, 32000],
+            [-1],
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+        ),
+    ],
+)
+# yapf: enable
+def test_replace_token_matches(token_ids, match_ids, new_ids, expected):
+    result = replace_token_matches(token_ids, match_ids, new_ids)
+
+    # Manually constructed results
+    assert result == expected
+
+
 # yapf: disable
 @pytest.mark.parametrize(
     ("prompt", "target_by_key", "expected_by_key"),

vllm/model_executor/models/gemma3_mm.py

Lines changed: 48 additions & 4 deletions
@@ -20,12 +20,15 @@
 from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
+# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, BoundPromptUpdate,
                                         PlaceholderFeaturesInfo,
-                                        PromptReplacement, PromptUpdate,
-                                        PromptUpdateDetails, encode_tokens,
-                                        find_mm_placeholders)
+                                        PromptReplacement, PromptTargetMatch,
+                                        PromptUpdate, PromptUpdateDetails,
+                                        encode_tokens, find_mm_placeholders,
+                                        replace_token_matches)
+# yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import flatten_2d_lists
@@ -320,6 +323,7 @@ def _call_hf_processor(
             len(image_repl_feature_tokens)
             for image_repl_feature_tokens in image_repls_feature_tokens
         ]
+        processed_outputs["num_embeds"] = torch.tensor(num_embeds)
 
         vocab = tokenizer.get_vocab()
         image_token_id = vocab[tokenizer.image_token]
@@ -337,7 +341,6 @@ def _call_hf_processor(
             for size in image_sizes
         ]
         processed_outputs["num_crops"] = torch.tensor(num_crops)
-        processed_outputs["num_embeds"] = torch.tensor(num_embeds)
 
         return processed_outputs
 
@@ -383,6 +386,47 @@ def get_replacement_gemma3(item_idx: int):
             )
         ]
 
+    def _apply_token_matches(
+        self,
+        prompt: list[int],
+        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
+        mm_item_counts: Mapping[str, int],
+    ) -> list[int]:
+        token_ids = super()._apply_token_matches(
+            prompt,
+            mm_matches,
+            mm_item_counts,
+        )
+
+        # "\n\n\n" and "\n\n\n\n" are single tokens
+        # Since our replacement can insert "\n\n" next to "\n"
+        # tokens, we have to combine them to be consistent with
+        # the output of the tokenizer
+        tokenizer = self.info.get_tokenizer()
+        vocab = tokenizer.get_vocab()
+        newline_1 = vocab["\n"]
+        newline_2 = vocab["\n\n"]
+        newline_3 = vocab["\n\n\n"]
+        newline_4 = vocab["\n\n\n\n"]
+
+        token_ids = replace_token_matches(
+            token_ids,
+            [newline_1, newline_2],
+            [newline_3],
+        )
+        token_ids = replace_token_matches(
+            token_ids,
+            [newline_2, newline_1],
+            [newline_3],
+        )
+        token_ids = replace_token_matches(
+            token_ids,
+            [newline_2, newline_2],
+            [newline_4],
+        )
+
+        return token_ids
+
     def _find_mm_placeholders(
         self,
         mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
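The override above exists because the Gemma3 tokenizer encodes runs of newlines as single tokens, so a replacement that leaves a "\n\n" token next to a "\n" token must be re-merged to match what the tokenizer would have produced on the raw text. A hedged illustration with made-up token ids (the real ids come from the Gemma3 vocab):

    from vllm.multimodal.processing import replace_token_matches

    # Hypothetical ids for "\n", "\n\n", "\n\n\n", "\n\n\n\n" respectively.
    NL1, NL2, NL3, NL4 = 107, 108, 109, 110

    token_ids = [NL1, NL2, 5, NL2, NL2]  # adjacent newline tokens after replacement
    token_ids = replace_token_matches(token_ids, [NL1, NL2], [NL3])
    token_ids = replace_token_matches(token_ids, [NL2, NL1], [NL3])  # no-op for this input
    token_ids = replace_token_matches(token_ids, [NL2, NL2], [NL4])
    assert token_ids == [NL3, 5, NL4]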

vllm/multimodal/processing.py

Lines changed: 65 additions & 22 deletions
@@ -511,8 +511,35 @@ def iter_token_matches(
         start_idx += 1
 
 
+def replace_token_matches(
+    token_ids: list[int],
+    match_ids: list[int],
+    new_ids: list[int],
+) -> list[int]:
+    """
+    Replace each occurrence of :code:`match_ids` in :code:`token_ids`
+    with :code:`new_ids`.
+
+    Note that empty matches are ignored.
+    """
+    out_seqs = list[list[int]]()
+    prev_end_idx = 0
+
+    for match in iter_token_matches(token_ids, match_ids):
+        start_idx = match.start_idx
+        end_idx = match.end_idx
+
+        out_seqs.append(token_ids[prev_end_idx:start_idx])
+        out_seqs.append(new_ids)
+        prev_end_idx = end_idx
+
+    out_seqs.append(token_ids[prev_end_idx:])
+
+    return flatten_2d_lists(out_seqs)
+
+
 @dataclass(repr=False)
-class _PromptTargetMatch(ABC):
+class PromptTargetMatch(ABC):
     _origin: BoundPromptUpdate
 
     @property
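In effect, the new function splices the prompt around each match: for every match it collects the tokens before the match followed by new_ids, then appends the tail after the last match, and flatten_2d_lists joins the segments. A small sketch of that splicing:

    from vllm.multimodal.processing import replace_token_matches

    # The single match of [28747, 32000] spans indices 1..3, so the result
    # is the prefix [9833] + replacement [-1] + suffix [918].
    assert replace_token_matches([9833, 28747, 32000, 918], [28747, 32000], [-1]) == [9833, -1, 918]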
@@ -535,7 +562,7 @@ def __repr__(self) -> str:
 
 
 @dataclass(repr=False)
-class _PromptTargetIndexMatch(_PromptTargetMatch):
+class _PromptTargetIndexMatch(PromptTargetMatch):
     match_idx: int
 
     @property
@@ -548,7 +575,7 @@ def end_idx(self) -> int:
 
 
 @dataclass(repr=False)
-class _PromptTargetTokenMatch(_PromptTargetMatch):
+class _PromptTargetTokenMatch(PromptTargetMatch):
     match: _TokenMatch
 
     @property
@@ -561,7 +588,7 @@ def end_idx(self) -> int:
 
 
 @dataclass(repr=False)
-class _PromptTargetTextMatch(_PromptTargetMatch):
+class _PromptTargetTextMatch(PromptTargetMatch):
     match: re.Match[str]
 
     @property
@@ -594,7 +621,7 @@ def to_range(self) -> PlaceholderRange:
 def find_token_matches(
     prompt: list[int],
     prompt_updates: Sequence[BoundPromptUpdate],
-) -> Sequence[_PromptTargetMatch]:
+) -> Sequence[PromptTargetMatch]:
     """Return each target of :code:`prompt_updates` found in :code:`prompt`."""
 
     def get_matches(update: BoundPromptUpdate):
@@ -620,7 +647,7 @@ def get_matches(update: BoundPromptUpdate):
 def find_text_matches(
     prompt: str,
     prompt_updates: Sequence[BoundPromptUpdate],
-) -> Sequence[_PromptTargetMatch]:
+) -> Sequence[PromptTargetMatch]:
     """Return each target of :code:`prompt_updates` found in :code:`prompt`."""
 
     def get_matches(update: BoundPromptUpdate):
@@ -645,15 +672,15 @@ def get_matches(update: BoundPromptUpdate):
 
 def _resolve_matches(
     prompt: PromptSeq,
-    mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
-) -> list[_PromptTargetMatch]:
+    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
+) -> list[PromptTargetMatch]:
     """
     Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
     and sort them such that earlier matches take priority over later ones.
     """
     matches = [m for matches in mm_matches.values() for m in matches]
 
-    seen_matches: list[Optional[_PromptTargetMatch]] = [None] * len(prompt)
+    seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt)
 
     for match in matches:
         for idx in range(match.start_idx, match.end_idx):
@@ -669,7 +696,7 @@ def _resolve_matches(
 
 def _apply_matches(
     prompt: _S,
-    mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
+    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
     mm_item_counts: Mapping[str, int],
 ) -> list[_S]:
     """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
@@ -718,7 +745,7 @@ def _apply_matches(
 
 def apply_token_matches(
     prompt: list[int],
-    mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
+    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
     mm_item_counts: Mapping[str, int],
 ) -> list[int]:
     """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
@@ -732,7 +759,7 @@ def apply_token_matches(
 
 def apply_text_matches(
     prompt: str,
-    mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
+    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
     mm_item_counts: Mapping[str, int],
 ) -> str:
     """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
@@ -1055,14 +1082,14 @@ def _get_prompt_updates(
         Given the original multi-modal items for this modality
         and HF-processed data, output the updates to perform.
 
-        Notes:
-            - You should not assume that HF processor always performs prompt
-              updates: in :meth:`_apply_hf_processor_missing`, this method
-              is called on text-only and multimodal-only inputs separately,
-              instead of passing them in the same call.
-            - The update information returned by this method is also used to
-              determine the placeholder token positions for each multi-modal
-              item.
+        The information returned by this method is used to update token inputs
+        which bypass the HF processor. It is also used to update the output of
+        the HF processor if the HF processor does not apply prompt updates to
+        text inputs.
+
+        Moreover, this information is critical to determine the token positions
+        in order to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
+        for each multi-modal item.
         """
         raise NotImplementedError
 
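As a concrete (hypothetical) illustration of the contract described in this docstring, an override might return one PromptReplacement per image; the subclass name, config attribute, and the fixed feature count of 64 below are illustrative, not taken from this commit:

    from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                            PromptReplacement)

    class MyProcessor(BaseMultiModalProcessor):  # hypothetical subclass
        def _get_prompt_updates(self, mm_items, hf_processor_mm_kwargs,
                                out_mm_kwargs):
            # Expand each single image placeholder token into a fixed-size
            # run of placeholder tokens (64 is made up for this sketch).
            image_token_id = self.info.get_hf_config().image_token_index
            return [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id],
                    replacement=[image_token_id] * 64,
                )
            ]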
@@ -1357,6 +1384,22 @@ def _bind_and_group_updates(
         it = (update.bind(tokenizer) for update in prompt_updates)
         return dict(full_groupby_modality(it))
 
+    def _apply_token_matches(
+        self,
+        prompt: list[int],
+        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
+        mm_item_counts: Mapping[str, int],
+    ) -> list[int]:
+        return apply_token_matches(prompt, mm_matches, mm_item_counts)
+
+    def _apply_text_matches(
+        self,
+        prompt: str,
+        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
+        mm_item_counts: Mapping[str, int],
+    ) -> str:
+        return apply_text_matches(prompt, mm_matches, mm_item_counts)
+
     def _apply_prompt_updates(
         self,
         token_ids: list[int],
@@ -1388,7 +1431,7 @@ def _apply_prompt_updates(
             mm_match_counts.get(modality, 0) >= item_count
             for modality, item_count in mm_item_counts.items()
         ):  # yapf: disable
-            token_ids = apply_token_matches(
+            token_ids = self._apply_token_matches(
                 token_ids,
                 mm_token_matches,
                 mm_item_counts,
@@ -1406,7 +1449,7 @@ def _apply_prompt_updates(
             modality: find_text_matches(text, updates)
             for modality, updates in mm_prompt_updates.items()
         }
-        text = apply_text_matches(
+        text = self._apply_text_matches(
             text,
             mm_text_matches,
             mm_item_counts,
