
Commit 6062817

DarkLight1337 authored and xuebwang-amd committed
[Optimization] Streamline InputPreprocessor (vllm-project#25702)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 6abb4cd commit 6062817
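
The change is summarized by the diff below: all of the `*_async` duplicates of the preprocessing methods are deleted, and the multi-modal processor is now created lazily once per InputPreprocessor and memoized, instead of being rebuilt via `mm_registry.create_processor` on every `_process_multimodal` call. A minimal sketch of that memoize-on-first-use pattern, with placeholder names standing in for the real classes (only `_get_mm_processor` itself appears in the diff):

class Preprocessor:
    """Toy stand-in for InputPreprocessor; names here are illustrative only."""

    def _build_processor(self):
        # In the real code this is mm_registry.create_processor(model_config,
        # tokenizer=..., cache=...); here it is only a placeholder object.
        return object()

    def _get_processor(self):
        # Build the expensive processor on first use and cache it on the
        # instance, mirroring the new _get_mm_processor in this commit.
        if not hasattr(self, "_processor"):
            self._processor = self._build_processor()
        return self._processor


p = Preprocessor()
assert p._get_processor() is p._get_processor()  # built once, then reused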

File tree

1 file changed: +12 -278 lines changed

vllm/inputs/preprocess.py

Lines changed: 12 additions & 278 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import asyncio
 from collections.abc import Mapping
 from typing import Any, Optional, Union, cast
 
@@ -13,6 +12,7 @@
 from vllm.multimodal.cache import BaseMultiModalProcessorCache
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                                     MultiModalInputs, MultiModalUUIDDict)
+from vllm.multimodal.processing import BaseMultiModalProcessor
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
@@ -200,20 +200,6 @@ def _tokenize_prompt(
 
         return tokenizer.encode(prompt, **tokenization_kwargs)
 
-    async def _tokenize_prompt_async(
-        self,
-        prompt: str,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-    ) -> list[int]:
-        """
-        Async version of
-        [`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
-        """
-        tokenizer = self.get_tokenizer()
-        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
-
-        return tokenizer.encode(prompt, **tokenization_kwargs)
-
     def _get_mm_tokenizer(self) -> AnyTokenizer:
         # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
         # while using also multi-modal input
@@ -223,14 +209,17 @@ def _get_mm_tokenizer(self) -> AnyTokenizer:
         tokenizer = self.get_tokenizer()
         return tokenizer
 
-    async def _get_mm_tokenizer_async(self) -> AnyTokenizer:
-        # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
-        # while using also multi-modal input
-        if not self.tokenizer:
-            return cast(AnyTokenizer, object())  # Dummy
+    def _get_mm_processor(self) -> BaseMultiModalProcessor:
+        if not hasattr(self, "_mm_processor"):
+            tokenizer = self._get_mm_tokenizer()
 
-        tokenizer = self.get_tokenizer()
-        return tokenizer
+            self._mm_processor = self.mm_registry.create_processor(
+                self.model_config,
+                tokenizer=tokenizer,
+                cache=self.mm_processor_cache,
+            )
+
+        return self._mm_processor
 
     def _process_multimodal(
         self,
@@ -245,55 +234,7 @@
         Apply the model's multi-modal processor to a multi-modal prompt,
         returning the corresponding token IDs and metadata.
         """
-        tokenizer = self._get_mm_tokenizer()
-
-        mm_processor = self.mm_registry.create_processor(
-            self.model_config,
-            tokenizer=tokenizer,
-            cache=self.mm_processor_cache,
-        )
-
-        if mm_processor_kwargs is None:
-            mm_processor_kwargs = {}
-
-        mm_input = mm_processor.apply(
-            prompt,
-            mm_data,
-            hf_processor_mm_kwargs=mm_processor_kwargs,
-            tokenization_kwargs=tokenization_kwargs,
-            mm_uuids=mm_uuids,
-        )
-        mm_hashes = mm_input["mm_hashes"]
-
-        # Validate that all mm items have a string as their hash
-        if not contains_only_strings(mm_hashes):
-            raise ValueError(
-                f"mm_hashes must contain only strings, got: {mm_hashes}. "
-                "This is likely due to an incorrect custom implementation of "
-                "MultiModalProcessor.apply method.")
-
-        return mm_input
-
-    async def _process_multimodal_async(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        mm_processor_kwargs: Optional[Mapping[str, object]],
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> MultiModalInputs:
-        """
-        Async version of
-        [`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
-        """
-        tokenizer = await self._get_mm_tokenizer_async()
-
-        mm_processor = self.mm_registry.create_processor(
-            self.model_config,
-            tokenizer=tokenizer,
-            cache=self.mm_processor_cache,
-        )
+        mm_processor = self._get_mm_processor()
 
         if mm_processor_kwargs is None:
             mm_processor_kwargs = {}
@@ -340,12 +281,6 @@ def _process_embeds(
         return embeds_inputs(prompt_embeds=prompt_embeds,
                              cache_salt=parsed_content.get("cache_salt"))
 
-    async def _process_embeds_async(
-        self,
-        parsed_content: EmbedsPrompt,
-    ) -> EmbedsInputs:
-        return self._process_embeds(parsed_content)
-
     def _truncate_inputs(
         self,
         inputs: list[int],
@@ -389,33 +324,6 @@ def _process_tokens(
 
         return inputs
 
-    async def _process_tokens_async(
-        self,
-        parsed_content: TokensPrompt,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> Union[TokenInputs, MultiModalInputs]:
-        prompt_token_ids = self._truncate_inputs(
-            parsed_content["prompt_token_ids"], tokenization_kwargs)
-
-        inputs: Union[TokenInputs, MultiModalInputs]
-        if multi_modal_data := parsed_content.get("multi_modal_data"):
-            inputs = await self._process_multimodal_async(
-                prompt_token_ids,
-                multi_modal_data,
-                parsed_content.get("mm_processor_kwargs"),
-                tokenization_kwargs=tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-        else:
-            inputs = token_inputs(prompt_token_ids=prompt_token_ids, )
-
-        if cache_salt := parsed_content.get("cache_salt"):
-            inputs["cache_salt"] = cache_salt
-
-        return inputs
-
     def _process_text(
         self,
         parsed_content: TextPrompt,
@@ -449,39 +357,6 @@ def _process_text(
 
         return inputs
 
-    async def _process_text_async(
-        self,
-        parsed_content: TextPrompt,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> Union[TokenInputs, MultiModalInputs]:
-        prompt_text = parsed_content["prompt"]
-
-        inputs: Union[TokenInputs, MultiModalInputs]
-        if multi_modal_data := parsed_content.get("multi_modal_data"):
-            inputs = await self._process_multimodal_async(
-                prompt_text,
-                multi_modal_data,
-                parsed_content.get("mm_processor_kwargs"),
-                tokenization_kwargs=tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-        else:
-            prompt_token_ids = await self._tokenize_prompt_async(
-                prompt_text,
-                tokenization_kwargs=tokenization_kwargs,
-            )
-            inputs = token_inputs(
-                prompt=prompt_text,
-                prompt_token_ids=prompt_token_ids,
-            )
-
-        if cache_salt := parsed_content.get("cache_salt"):
-            inputs["cache_salt"] = cache_salt
-
-        return inputs
-
     def _prompt_to_llm_inputs(
         self,
         prompt: SingletonPrompt,
@@ -524,41 +399,6 @@ def _prompt_to_llm_inputs(
 
         assert_never(parsed)
 
-    async def _prompt_to_llm_inputs_async(
-        self,
-        prompt: SingletonPrompt,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> SingletonInputs:
-        """
-        Async version of
-        [`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs].
-        """
-        parsed = parse_singleton_prompt(prompt)
-
-        if parsed["type"] == "embeds":
-            return await self._process_embeds_async(parsed["content"])
-        if parsed["type"] == "tokens":
-            return await self._process_tokens_async(
-                parsed["content"],
-                mm_uuids=mm_uuids,
-            )
-        if parsed["type"] == "text":
-            return await self._process_text_async(
-                parsed["content"],
-                tokenization_kwargs=tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-        if parsed["type"] == "str":
-            return await self._process_text_async(
-                TextPrompt(prompt=parsed["content"]),
-                tokenization_kwargs=tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-
-        assert_never(parsed)
-
     def _build_enc_dec_llm_inputs(
         self,
         encoder_inputs: SingletonInputs,
@@ -735,62 +575,6 @@ def _process_encoder_decoder_prompt(
 
         return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
 
-    async def _process_encoder_decoder_prompt_async(
-        self,
-        prompt: PromptType,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> EncoderDecoderInputs:
-        """
-        Async version of
-        [`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt].
-        """
-        encoder_inputs: SingletonInputs
-        decoder_inputs: Optional[SingletonInputs]
-
-        if is_explicit_encoder_decoder_prompt(prompt):
-            encoder_task = self._prompt_to_llm_inputs_async(
-                prompt["encoder_prompt"],
-                tokenization_kwargs=tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-
-            if (decoder_input := prompt["decoder_prompt"]) is None:
-                encoder_inputs = await encoder_task
-                decoder_inputs = None
-            else:
-                decoder_task = self._prompt_to_llm_inputs_async(
-                    decoder_input,
-                    tokenization_kwargs=tokenization_kwargs,
-                    mm_uuids=mm_uuids,
-                )
-
-                encoder_inputs, decoder_inputs = await asyncio.gather(
-                    encoder_task, decoder_task)
-
-            # For multimodal model, override decoder prompt from processor
-            # with explicit decoder prompt.
-            if self.model_config.is_multimodal_model:
-                encoder_inputs, decoder_inputs = (
-                    self._split_enc_dec_mm_inputs(encoder_inputs,
-                                                  decoder_inputs))
-        else:
-            inputs = await self._prompt_to_llm_inputs_async(
-                prompt,
-                tokenization_kwargs=tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-            if self.model_config.is_multimodal_model:
-                # Encoder-Decoder Multimodal model
-                encoder_inputs, decoder_inputs = (
-                    self._split_enc_dec_mm_inputs(inputs))
-            else:
-                encoder_inputs = inputs
-                decoder_inputs = None
-
-        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
-
     def _build_decoder_only_llm_inputs(
         self,
         prompt_inputs: DecoderOnlyInputs,
@@ -830,25 +614,6 @@ def _process_decoder_only_prompt(
 
         return self._build_decoder_only_llm_inputs(prompt_comps)
 
-    async def _process_decoder_only_prompt_async(
-        self,
-        prompt: SingletonPrompt,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> DecoderOnlyInputs:
-        """
-        Async version of
-        [`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt].
-        """
-        prompt_comps = await self._prompt_to_llm_inputs_async(
-            prompt,
-            tokenization_kwargs=tokenization_kwargs,
-            mm_uuids=mm_uuids,
-        )
-
-        return self._build_decoder_only_llm_inputs(prompt_comps)
-
     def preprocess(
         self,
         prompt: PromptType,
@@ -877,37 +642,6 @@ def preprocess(
             mm_uuids=mm_uuids,
         )
 
-    async def preprocess_async(
-        self,
-        prompt: PromptType,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
-        *,
-        mm_uuids: Optional[MultiModalUUIDDict] = None,
-    ) -> ProcessorInputs:
-        """
-        Async version of
-        [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
-        """
-        if self.model_config.is_encoder_decoder:
-            # Encoder-decoder model requires special mapping of
-            # input prompts to encoder & decoder.
-            return await self._process_encoder_decoder_prompt_async(
-                prompt,
-                tokenization_kwargs,
-                mm_uuids=mm_uuids,
-            )
-
-        if is_explicit_encoder_decoder_prompt(prompt):
-            raise ValueError("Cannot pass encoder-decoder prompt "
-                             "to decoder-only models")
-
-        # Decoder-only operation
-        return await self._process_decoder_only_prompt_async(
-            prompt,
-            tokenization_kwargs=tokenization_kwargs,
-            mm_uuids=mm_uuids,
-        )
-
     def clear_cache(self) -> None:
         if self.mm_processor_cache is not None:
             self.mm_processor_cache.clear_cache()
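
With `preprocess_async` and the other `*_async` helpers removed, the synchronous `preprocess` is the single entry point. How vLLM's async engine consumes it after this change is not shown here; the sketch below is only one hedged way an asyncio caller could keep the event loop unblocked, and the helper name is an illustrative assumption, not part of this commit:

import asyncio


async def preprocess_in_thread(preprocessor, prompt, **kwargs):
    # Illustrative helper (not from this commit): run the now-synchronous
    # InputPreprocessor.preprocess() in a worker thread so the asyncio event
    # loop is not blocked while tokenization and multi-modal processing run.
    return await asyncio.to_thread(preprocessor.preprocess, prompt, **kwargs)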

0 commit comments
