11# SPDX-License-Identifier: Apache-2.0
22
33import time
4- from collections .abc import Mapping
4+ from collections .abc import Mapping , Sequence
55from typing import Literal , Optional , Union
66
77from vllm .config import VllmConfig
1919from vllm .sampling_params import SamplingParams
2020from vllm .transformers_utils .tokenizer_group import BaseTokenizerGroup
2121from vllm .v1 .engine import EngineCoreRequest
22+ from vllm .v1 .engine .mm_input_cache import MirroredProcessingCache
2223from vllm .v1 .structured_output .backend_guidance import (
2324 validate_guidance_grammar )
2425from vllm .v1 .structured_output .utils import (
@@ -47,6 +48,8 @@ def __init__(
4748 self .tokenizer ,
4849 mm_registry )
4950
51+ self .mm_input_cache_client = MirroredProcessingCache (self .model_config )
52+
5053 # Multi-modal hasher (for images)
5154 self .use_hash = (
5255 not self .model_config .disable_mm_preprocessor_cache ) or \
@@ -231,7 +234,7 @@ def process_inputs(
231234 self .tokenizer .get_lora_tokenizer (lora_request ))
232235
233236 # Multimodal related.
234- sorted_mm_inputs : Optional [list [ MultiModalKwargs ]] = None
237+ sorted_mm_inputs : Optional [Sequence [ Optional [ MultiModalKwargs ] ]] = None
235238 sorted_mm_positions : Optional [list [PlaceholderRange ]] = None
236239 sorted_mm_hashes : Optional [list [str ]] = None
237240 if decoder_inputs ["type" ] == "multimodal" :
@@ -256,20 +259,28 @@ def process_inputs(
256259 # are multiple modalities.
257260 unique_modalities = set (sorted_item_modalities )
258261 if len (unique_modalities ) > 1 :
259- sorted_mm_inputs = []
262+ orig_sorted_mm_inputs = []
260263 used_indices = {modality : 0 for modality in unique_modalities }
264+
261265 for modality in sorted_item_modalities :
262266 items = decoder_mm_inputs .get_items (modality )
263267 item = items [used_indices [modality ]]
264- sorted_mm_inputs .append (MultiModalKwargs .from_items ([item
265- ]))
268+
269+ orig_sorted_mm_inputs .append (
270+ MultiModalKwargs .from_items ([item ]))
266271 used_indices [modality ] += 1
267272 else :
268- sorted_mm_inputs = [
273+ orig_sorted_mm_inputs = [
269274 MultiModalKwargs .from_items ([item ]) for item in
270275 decoder_mm_inputs .get_items (sorted_item_modalities [0 ])
271276 ]
272277
278+ if sorted_mm_hashes is not None :
279+ sorted_mm_inputs = self .mm_input_cache_client .get_and_update_p0 (
280+ orig_sorted_mm_inputs , sorted_mm_hashes )
281+ else :
282+ sorted_mm_inputs = orig_sorted_mm_inputs
283+
273284 return EngineCoreRequest (
274285 request_id = request_id ,
275286 prompt = decoder_inputs .get ("prompt" ),
0 commit comments