@@ -511,8 +511,35 @@ def iter_token_matches(
511511 start_idx += 1
512512
513513
def replace_token_matches(
    token_ids: list[int],
    match_ids: list[int],
    new_ids: list[int],
) -> list[int]:
    """
    Substitute :code:`new_ids` for every occurrence of :code:`match_ids`
    found in :code:`token_ids`, returning the resulting token list.

    Occurrences are the ones yielded by :func:`iter_token_matches`.
    Note that empty matches are ignored.
    """
    pieces: list[list[int]] = []
    cursor = 0

    for hit in iter_token_matches(token_ids, match_ids):
        # Keep the untouched tokens preceding this hit, then the replacement.
        pieces += [token_ids[cursor:hit.start_idx], new_ids]
        cursor = hit.end_idx

    # Tokens following the final hit (the whole input when nothing matched).
    pieces.append(token_ids[cursor:])

    return flatten_2d_lists(pieces)
539+
540+
514541@dataclass (repr = False )
515- class _PromptTargetMatch (ABC ):
542+ class PromptTargetMatch (ABC ):
516543 _origin : BoundPromptUpdate
517544
518545 @property
@@ -535,7 +562,7 @@ def __repr__(self) -> str:
535562
536563
537564@dataclass (repr = False )
538- class _PromptTargetIndexMatch (_PromptTargetMatch ):
565+ class _PromptTargetIndexMatch (PromptTargetMatch ):
539566 match_idx : int
540567
541568 @property
@@ -548,7 +575,7 @@ def end_idx(self) -> int:
548575
549576
550577@dataclass (repr = False )
551- class _PromptTargetTokenMatch (_PromptTargetMatch ):
578+ class _PromptTargetTokenMatch (PromptTargetMatch ):
552579 match : _TokenMatch
553580
554581 @property
@@ -561,7 +588,7 @@ def end_idx(self) -> int:
561588
562589
563590@dataclass (repr = False )
564- class _PromptTargetTextMatch (_PromptTargetMatch ):
591+ class _PromptTargetTextMatch (PromptTargetMatch ):
565592 match : re .Match [str ]
566593
567594 @property
@@ -594,7 +621,7 @@ def to_range(self) -> PlaceholderRange:
594621def find_token_matches (
595622 prompt : list [int ],
596623 prompt_updates : Sequence [BoundPromptUpdate ],
597- ) -> Sequence [_PromptTargetMatch ]:
624+ ) -> Sequence [PromptTargetMatch ]:
598625 """Return each target of :code:`prompt_updates` found in :code:`prompt`."""
599626
600627 def get_matches (update : BoundPromptUpdate ):
@@ -620,7 +647,7 @@ def get_matches(update: BoundPromptUpdate):
620647def find_text_matches (
621648 prompt : str ,
622649 prompt_updates : Sequence [BoundPromptUpdate ],
623- ) -> Sequence [_PromptTargetMatch ]:
650+ ) -> Sequence [PromptTargetMatch ]:
624651 """Return each target of :code:`prompt_updates` found in :code:`prompt`."""
625652
626653 def get_matches (update : BoundPromptUpdate ):
@@ -645,15 +672,15 @@ def get_matches(update: BoundPromptUpdate):
645672
646673def _resolve_matches (
647674 prompt : PromptSeq ,
648- mm_matches : Mapping [str , Sequence [_PromptTargetMatch ]],
649- ) -> list [_PromptTargetMatch ]:
675+ mm_matches : Mapping [str , Sequence [PromptTargetMatch ]],
676+ ) -> list [PromptTargetMatch ]:
650677 """
651678 Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
652679 and sort them such that earlier matches take priority over later ones.
653680 """
654681 matches = [m for matches in mm_matches .values () for m in matches ]
655682
656- seen_matches : list [Optional [_PromptTargetMatch ]] = [None ] * len (prompt )
683+ seen_matches : list [Optional [PromptTargetMatch ]] = [None ] * len (prompt )
657684
658685 for match in matches :
659686 for idx in range (match .start_idx , match .end_idx ):
@@ -669,7 +696,7 @@ def _resolve_matches(
669696
670697def _apply_matches (
671698 prompt : _S ,
672- mm_matches : Mapping [str , Sequence [_PromptTargetMatch ]],
699+ mm_matches : Mapping [str , Sequence [PromptTargetMatch ]],
673700 mm_item_counts : Mapping [str , int ],
674701) -> list [_S ]:
675702 """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
@@ -718,7 +745,7 @@ def _apply_matches(
718745
719746def apply_token_matches (
720747 prompt : list [int ],
721- mm_matches : Mapping [str , Sequence [_PromptTargetMatch ]],
748+ mm_matches : Mapping [str , Sequence [PromptTargetMatch ]],
722749 mm_item_counts : Mapping [str , int ],
723750) -> list [int ]:
724751 """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
@@ -732,7 +759,7 @@ def apply_token_matches(
732759
733760def apply_text_matches (
734761 prompt : str ,
735- mm_matches : Mapping [str , Sequence [_PromptTargetMatch ]],
762+ mm_matches : Mapping [str , Sequence [PromptTargetMatch ]],
736763 mm_item_counts : Mapping [str , int ],
737764) -> str :
738765 """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
@@ -1055,14 +1082,14 @@ def _get_prompt_updates(
10551082 Given the original multi-modal items for this modality
10561083 and HF-processed data, output the updates to perform.
10571084
1058- Notes:
1059- - You should not assume that HF processor always performs prompt
1060- updates: in :meth:`_apply_hf_processor_missing`, this method
1061- is called on text-only and multimodal-only inputs separately,
1062- instead of passing them in the same call.
1063- - The update information returned by this method is also used to
1064- determine the placeholder token positions for each multi-modal
1065- item.
1085+ The information returned by this method is used to update token inputs
1086+ which bypass the HF processor. It is also used to update the output of
1087+ the HF processor if the HF processor does not apply prompt updates to
1088+ text inputs.
1089+
1090+ Moreover, this information is critical to determine the token positions
1091+ in order to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
1092+ for each multi-modal item.
10661093 """
10671094 raise NotImplementedError
10681095
@@ -1357,6 +1384,22 @@ def _bind_and_group_updates(
13571384 it = (update .bind (tokenizer ) for update in prompt_updates )
13581385 return dict (full_groupby_modality (it ))
13591386
def _apply_token_matches(
    self,
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the updates in :code:`mm_matches` to the token :code:`prompt`.

    This method forwards its arguments unchanged to the module-level
    :func:`apply_token_matches` and returns that function's result.
    """
    return apply_token_matches(
        prompt=prompt,
        mm_matches=mm_matches,
        mm_item_counts=mm_item_counts,
    )
1394+
def _apply_text_matches(
    self,
    prompt: str,
    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the updates in :code:`mm_matches` to the text :code:`prompt`.

    This method forwards its arguments unchanged to the module-level
    :func:`apply_text_matches` and returns that function's result.
    """
    return apply_text_matches(
        prompt=prompt,
        mm_matches=mm_matches,
        mm_item_counts=mm_item_counts,
    )
1402+
13601403 def _apply_prompt_updates (
13611404 self ,
13621405 token_ids : list [int ],
@@ -1388,7 +1431,7 @@ def _apply_prompt_updates(
13881431 mm_match_counts .get (modality , 0 ) >= item_count
13891432 for modality , item_count in mm_item_counts .items ()
13901433 ): # yapf: disable
1391- token_ids = apply_token_matches (
1434+ token_ids = self . _apply_token_matches (
13921435 token_ids ,
13931436 mm_token_matches ,
13941437 mm_item_counts ,
@@ -1406,7 +1449,7 @@ def _apply_prompt_updates(
14061449 modality : find_text_matches (text , updates )
14071450 for modality , updates in mm_prompt_updates .items ()
14081451 }
1409- text = apply_text_matches (
1452+ text = self . _apply_text_matches (
14101453 text ,
14111454 mm_text_matches ,
14121455 mm_item_counts ,
0 commit comments