
Commit d3f71f1

[Refactor] Get prompt updates earlier (#23097)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 5a30bd1 commit d3f71f1

6 files changed (+84 −69 lines)

vllm/model_executor/models/deepseek_vl2.py

Lines changed: 3 additions & 3 deletions
@@ -25,7 +25,8 @@
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, MultiModalHashes,
+                                        BaseProcessingInfo,
+                                        MultiModalProcessingInfo,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
@@ -291,8 +292,7 @@ def _cached_apply_hf_processor(
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         # The processor logic is different for len(images) <= 2 vs > 2
         # Since the processing cache assumes that the processor output is
         # invariant of how many images are passed per prompt, we only

vllm/model_executor/models/h2ovl.py

Lines changed: 4 additions & 4 deletions
@@ -20,8 +20,9 @@
 from vllm.multimodal.inputs import MultiModalKwargsItems
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    MultiModalDataItems)
-from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.processing import (MultiModalProcessingInfo,
+                                        PromptReplacement, PromptUpdate,
+                                        PromptUpdateDetails)
 from vllm.transformers_utils.tokenizer import AnyTokenizer

 from .intern_vit import InternVisionModel
@@ -480,8 +481,7 @@ def _cached_apply_hf_processor(
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         # The processor logic is different for len(images) <= 1 vs > 1
         # Since the processing cache assumes that the processor output is
         # invariant of how many images are passed per prompt, we only

vllm/model_executor/models/pixtral.py

Lines changed: 5 additions & 10 deletions
@@ -39,7 +39,8 @@
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, MultiModalHashes,
+                                        BaseProcessingInfo,
+                                        MultiModalProcessingInfo,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
@@ -309,14 +310,8 @@ def _cached_apply_hf_processor(
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
-        (
-            prompt_ids,
-            mm_kwargs,
-            mm_hashes,
-            _,
-        ) = super()._cached_apply_hf_processor(
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
+        prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -325,7 +320,7 @@ def _cached_apply_hf_processor(
         )

         # NOTE: The tokens are already inserted by the chat template
-        return prompt_ids, mm_kwargs, mm_hashes, True
+        return prompt_ids, mm_info, True


 @MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
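
The Pixtral processor (and Voxtral further down) now unpacks three values from the parent call instead of four; the kwargs and hashes travel inside `mm_info`. A plain-Python toy of the call-site difference, with dummy tuples standing in for the real vLLM objects:

    # Toy illustration (dummy values, not vLLM objects) of the unpacking change.
    old_result = ([1, 2, 3], {"image": []}, None, False)    # (ids, kwargs, hashes, applied)
    prompt_ids, mm_kwargs, mm_hashes, _ = old_result         # before this commit

    new_result = ([1, 2, 3], ("kwargs", "hashes", "updates"), False)  # (ids, mm_info, applied)
    prompt_ids, mm_info, _ = new_result                      # after this commit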

vllm/model_executor/models/qwen2_5_omni_thinker.py

Lines changed: 15 additions & 18 deletions
@@ -59,6 +59,7 @@
                                    ModalityDataItems, MultiModalDataItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        MultiModalPromptUpdates,
                                         PlaceholderFeaturesInfo,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
@@ -88,10 +89,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
     video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
     video_grid_sizes = video_grid_thw.prod(-1)

-    # vllm use `second_per_grid_ts` to compute multimodal rotary embedding
-    video_second_per_grid = hf_inputs.get("video_second_per_grid", None)
-    if video_second_per_grid is not None:
-        hf_inputs["second_per_grid_ts"] = video_second_per_grid
+    num_videos = len(video_grid_sizes)

     return dict(
         input_audio_features=MultiModalFieldConfig.flat_from_sizes(
@@ -109,6 +107,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
             "video", video_grid_sizes),
         video_grid_thw=MultiModalFieldConfig.batched("video"),
         second_per_grid_ts=MultiModalFieldConfig.batched("video"),
+        use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos),
     )


@@ -251,6 +250,14 @@ def _call_hf_processor(
         if ('audio_feature_lengths' not in hf_inputs
                 and feature_attention_mask is not None):
             hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1)
+
+        video_second_per_grid = hf_inputs.get("video_second_per_grid", None)
+        if video_second_per_grid is not None:
+            hf_inputs["second_per_grid_ts"] = video_second_per_grid
+
+        use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
+        hf_inputs["use_audio_in_video"] = torch.tensor(use_audio_in_video)
+
         return hf_inputs

     def _get_mm_fields_config(
@@ -263,27 +270,20 @@ def _get_mm_fields_config(
     def _maybe_apply_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
         prompt_ids: list[int],
         mm_kwargs: MultiModalKwargsItems,
+        mm_prompt_updates: MultiModalPromptUpdates,
         is_update_applied: bool,
     ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
         """
         Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
         """
-        unbound_prompt_updates = self._get_prompt_updates(
-            mm_items,
-            hf_processor_mm_kwargs,
-            mm_kwargs,
-        )
-        mm_prompt_updates = self._bind_and_group_updates(
-            unbound_prompt_updates)
-
         mm_item_counts = mm_items.get_all_counts()
         self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

-        use_audio_in_video = hf_processor_mm_kwargs.get(
-            "use_audio_in_video", False)
+        use_audio_in_video = (all(
+            item["use_audio_in_video"].data
+            for item in mm_kwargs["video"]) if "video" in mm_kwargs else False)

         if is_update_applied:
             mm_placeholders = self._find_mm_placeholders(
@@ -316,9 +316,6 @@ def _maybe_apply_prompt_updates(
         tokenizer = self.info.get_tokenizer()
         prompt = decode_tokens(tokenizer, prompt_ids)

-        if use_audio_in_video:
-            mm_kwargs["use_audio_in_video"] = True
-
         return prompt_ids, prompt, mm_placeholders

     def _get_prompt_updates(
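
With this change, `use_audio_in_video` no longer rides on `hf_processor_mm_kwargs`: `_call_hf_processor` stores it in the HF outputs, `_get_mm_fields_config` registers it as a field shared across the video items, and `_maybe_apply_prompt_updates` recovers it from the (possibly cached) `mm_kwargs`. A self-contained sketch of that recovery, with plain booleans standing in for the tensor-valued items used in the real code:

    # Stand-in structure: each video item carries its own copy of the flag.
    mm_kwargs = {"video": [{"use_audio_in_video": True},
                           {"use_audio_in_video": True}]}

    use_audio_in_video = (all(item["use_audio_in_video"]
                              for item in mm_kwargs["video"])
                          if "video" in mm_kwargs else False)
    print(use_audio_in_video)  # True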

vllm/model_executor/models/voxtral.py

Lines changed: 5 additions & 6 deletions
@@ -35,7 +35,8 @@
 from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, MultiModalHashes,
+                                        BaseProcessingInfo,
+                                        MultiModalProcessingInfo,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -289,10 +290,8 @@ def _cached_apply_hf_processor(
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
-        prompt_ids, mm_kwargs, mm_hashes, _ = super(
-        )._cached_apply_hf_processor(
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
+        prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -301,7 +300,7 @@ def _cached_apply_hf_processor(
         )

         # NOTE: The tokens are already inserted by the chat template
-        return prompt_ids, mm_kwargs, mm_hashes, True
+        return prompt_ids, mm_info, True

     def _get_data_parser(self) -> MultiModalDataParser:
         sampling_rate = self.info.get_hf_processor().sampling_rate

vllm/multimodal/processing.py

Lines changed: 52 additions & 28 deletions
@@ -989,6 +989,18 @@ def get_mm_max_tokens_per_item(
 [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
 """

+MultiModalPromptUpdates = dict[str, Sequence[BoundPromptUpdate]]
+"""
+A collection of prompt updates with a similar structure as
+[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
+"""
+
+
+class MultiModalProcessingInfo(NamedTuple):
+    kwargs: MultiModalKwargsItems
+    hashes: Optional[MultiModalHashes]
+    prompt_updates: MultiModalPromptUpdates
+

 class BaseMultiModalProcessor(ABC, Generic[_I]):
     """
@@ -1363,7 +1375,7 @@ def _merge_mm_kwargs(
         cache: ProcessingCache,
         mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
         mm_missing_kwargs: MultiModalKwargsItems,
-    ) -> dict[str, list[MultiModalKwargsItem]]:
+    ) -> MultiModalKwargsItems:
         mm_missing_next_idx = defaultdict[str, int](lambda: 0)

         merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
@@ -1379,7 +1391,7 @@

                 merged_items[modality].append(kw_item)

-        return dict(merged_items)
+        return MultiModalKwargsItems(merged_items)

     def _apply_hf_processor(
         self,
@@ -1389,8 +1401,7 @@ def _apply_hf_processor(
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         (
             prompt_ids,
             mm_processed_data,
@@ -1413,7 +1424,21 @@
                                          tokenization_kwargs)
                      if return_mm_hashes else None)

-        return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
+        unbound_prompt_updates = self._get_prompt_updates(
+            mm_data_items,
+            hf_processor_mm_kwargs,
+            mm_kwargs,
+        )
+        mm_prompt_updates = self._bind_and_group_updates(
+            unbound_prompt_updates)
+
+        mm_info = MultiModalProcessingInfo(
+            kwargs=mm_kwargs,
+            hashes=mm_hashes,
+            prompt_updates=mm_prompt_updates,
+        )
+
+        return prompt_ids, mm_info, is_update_applied

     def _cached_apply_hf_processor(
         self,
@@ -1423,8 +1448,7 @@ def _cached_apply_hf_processor(
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         """
         Apply the HF processor on the full prompt text,
         caching the results and reusing cached results.
@@ -1475,18 +1499,27 @@ def _cached_apply_hf_processor(
                                        hf_processor_mm_kwargs),
         )

-        mm_cache_items_merged = self._merge_mm_kwargs(
+        mm_kwargs = self._merge_mm_kwargs(
             cache,
             mm_cache_items_or_hashes=mm_cache_items_or_hashes,
             mm_missing_kwargs=mm_missing_kwargs,
         )

-        mm_kwargs = MultiModalKwargsItems.from_seq([
-            item for cache_items in mm_cache_items_merged.values()
-            for item in cache_items
-        ])
+        unbound_prompt_updates = self._get_prompt_updates(
+            mm_data_items,
+            hf_processor_mm_kwargs,
+            mm_kwargs,
+        )
+        mm_prompt_updates = self._bind_and_group_updates(
+            unbound_prompt_updates)
+
+        mm_info = MultiModalProcessingInfo(
+            kwargs=mm_kwargs,
+            hashes=mm_hashes_to_return,
+            prompt_updates=mm_prompt_updates,
+        )

-        return prompt_ids, mm_kwargs, mm_hashes_to_return, is_update_applied
+        return prompt_ids, mm_info, is_update_applied

     def _bind_and_group_updates(
         self,
@@ -1626,19 +1659,11 @@ def _validate_mm_placeholders(
     def _maybe_apply_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
         prompt_ids: list[int],
         mm_kwargs: MultiModalKwargsItems,
+        mm_prompt_updates: MultiModalPromptUpdates,
         is_update_applied: bool,
     ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
-        unbound_prompt_updates = self._get_prompt_updates(
-            mm_items,
-            hf_processor_mm_kwargs,
-            mm_kwargs,
-        )
-        mm_prompt_updates = self._bind_and_group_updates(
-            unbound_prompt_updates)
-
         mm_item_counts = mm_items.get_all_counts()
         self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

@@ -1694,8 +1719,7 @@ def apply(

         (
             prompt_ids,
-            mm_kwargs,
-            mm_hashes,
+            mm_info,
             is_update_applied,
         ) = self._cached_apply_hf_processor(
             prompt,
@@ -1708,9 +1732,9 @@ def apply(
         # NOTE: tokenization_kwargs are not required to init processor
         prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
             mm_items=mm_items,
-            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             prompt_ids=prompt_ids,
-            mm_kwargs=mm_kwargs,
+            mm_kwargs=mm_info.kwargs,
+            mm_prompt_updates=mm_info.prompt_updates,
             is_update_applied=is_update_applied,
         )

@@ -1723,8 +1747,8 @@ def apply(
             type="multimodal",
             prompt=prompt,
             prompt_token_ids=prompt_ids,
-            mm_kwargs=mm_kwargs,
-            mm_hashes=mm_hashes,
+            mm_kwargs=mm_info.kwargs,
+            mm_hashes=mm_info.hashes,
             mm_placeholders=mm_placeholder_ranges,
         )

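Taken together, the base processor now computes and binds the prompt updates inside `_apply_hf_processor` / `_cached_apply_hf_processor` (hence "earlier") and hands them to `_maybe_apply_prompt_updates` through the new `MultiModalProcessingInfo` container instead of re-deriving them from `hf_processor_mm_kwargs`. A minimal, self-contained sketch of the container and of how a caller such as `apply()` reads it; the field types here are simplified stand-ins for the vLLM classes named in the diff:

    from typing import Any, NamedTuple, Optional

    class MultiModalProcessingInfo(NamedTuple):
        kwargs: Any            # MultiModalKwargsItems in vLLM
        hashes: Optional[Any]  # MultiModalHashes in vLLM
        prompt_updates: Any    # MultiModalPromptUpdates (modality -> bound updates)

    # Hypothetical result of _cached_apply_hf_processor, for illustration only:
    prompt_ids = [101, 102, 103]
    mm_info = MultiModalProcessingInfo(kwargs={"image": []},
                                       hashes=None,
                                       prompt_updates={"image": []})
    is_update_applied = False

    # apply() now reads everything it needs off mm_info:
    mm_kwargs = mm_info.kwargs
    mm_hashes = mm_info.hashes
    mm_prompt_updates = mm_info.prompt_updates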