Commit 5ec9c0f

[Core] Factor out input preprocessing to a separate class (#7329)
DarkLight1337 authored Sep 13, 2024
1 parent 8f44a92 commit 5ec9c0f
Showing 5 changed files with 590 additions and 540 deletions.
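
For orientation, here is a rough, self-contained sketch of the shape such a factored-out preprocessing class can take. Only the attribute name input_preprocessor, the method name preprocess_async, and the "cannot pass text prompts when" error text are taken from the diff below; the tokenizer protocol, the PreprocessedInputs container, and everything else in the sketch are illustrative assumptions, not vLLM's actual interfaces.

from dataclasses import dataclass
from typing import List, Optional, Protocol, Union


class SupportsEncode(Protocol):
    """Minimal stand-in for a tokenizer group with sync and async encode."""

    def encode(self, prompt: str) -> List[int]: ...

    async def encode_async(self, prompt: str) -> List[int]: ...


@dataclass
class PreprocessedInputs:
    """Illustrative container for the preprocessor's output."""
    prompt: Optional[str]
    prompt_token_ids: List[int]


class InputPreprocessorSketch:
    """Holds the prompt handling that previously lived on each engine class."""

    def __init__(self, tokenizer: Optional[SupportsEncode]) -> None:
        self.tokenizer = tokenizer

    def _require_tokenizer(self) -> SupportsEncode:
        if self.tokenizer is None:
            # Matches the error exercised in tests/engine/test_skip_tokenizer_init.py;
            # the full wording here is assumed.
            raise ValueError(
                "cannot pass text prompts when skip_tokenizer_init is True")
        return self.tokenizer

    def preprocess(self, inputs: Union[str, dict]) -> PreprocessedInputs:
        """Sync path: tokenize text prompts, pass token-id prompts through."""
        if isinstance(inputs, str):
            return PreprocessedInputs(
                inputs, self._require_tokenizer().encode(inputs))
        if "prompt_token_ids" in inputs:
            return PreprocessedInputs(None, inputs["prompt_token_ids"])
        prompt = inputs["prompt"]
        return PreprocessedInputs(
            prompt, self._require_tokenizer().encode(prompt))

    async def preprocess_async(
            self, inputs: Union[str, dict]) -> PreprocessedInputs:
        """Async path: same branching, awaiting the tokenizer instead."""
        if isinstance(inputs, str):
            token_ids = await self._require_tokenizer().encode_async(inputs)
            return PreprocessedInputs(inputs, token_ids)
        if "prompt_token_ids" in inputs:
            return PreprocessedInputs(None, inputs["prompt_token_ids"])
        prompt = inputs["prompt"]
        token_ids = await self._require_tokenizer().encode_async(prompt)
        return PreprocessedInputs(prompt, token_ids)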
5 changes: 3 additions & 2 deletions tests/engine/test_skip_tokenizer_init.py
@@ -11,9 +11,10 @@ def test_skip_tokenizer_initialization(model: str):
     # token ids.
     llm = LLM(model=model, skip_tokenizer_init=True)
     sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
-    with pytest.raises(ValueError) as err:
+
+    with pytest.raises(ValueError, match="cannot pass text prompts when"):
         llm.generate("abc", sampling_params)
-    assert "prompts must be None if" in str(err.value)
+
     outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
                            sampling_params=sampling_params)
     assert len(outputs) > 0
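
The updated test uses the match argument of pytest.raises, which applies a regular-expression search to the string form of the raised exception, in place of capturing the exception and asserting on err.value afterwards. A minimal standalone illustration of the two equivalent styles (the helper function and message are invented for the example):

import pytest


def fail(msg: str) -> None:
    raise ValueError(msg)


def test_match_keyword() -> None:
    # Preferred: pytest.raises itself checks the message via re.search.
    with pytest.raises(ValueError, match="bad input"):
        fail("bad input: missing field")


def test_manual_assert() -> None:
    # Equivalent older style: capture the exception, then assert on it.
    with pytest.raises(ValueError) as err:
        fail("bad input: missing field")
    assert "bad input" in str(err.value)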
145 changes: 4 additions & 141 deletions vllm/engine/async_llm_engine.py
@@ -4,22 +4,17 @@
 from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                     Mapping, Optional, Set, Tuple, Type, Union)

-from typing_extensions import assert_never
-
 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
-from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
-                                    PromptComponents, SchedulerOutputState)
+from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
 from vllm.engine.metrics_types import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
-                         SingletonPromptInputs)
-from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
+from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -403,139 +398,6 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
         """Stop the remote worker execution loop."""
         await self.model_executor.stop_remote_worker_execution_loop_async()

-    async def _tokenize_prompt_async(
-        self,
-        prompt: str,
-        request_id: str,
-        lora_request: Optional[LoRARequest],
-    ) -> List[int]:
-        """Async version of :meth:`_tokenize_prompt`."""
-        tokenizer = self.get_tokenizer_group(
-            missing_msg="prompts must be None if skip_tokenizer_init is True")
-
-        return await tokenizer.encode_async(request_id=request_id,
-                                            prompt=prompt,
-                                            lora_request=lora_request)
-
-    async def _extract_prompt_components_async(
-        self,
-        inputs: SingletonPromptInputs,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-    ) -> PromptComponents:
-        """Async version of :meth:`_extract_prompt_components`."""
-        if isinstance(inputs, str):
-            prompt = inputs
-            prompt_token_ids = await self._tokenize_prompt_async(
-                prompt,
-                request_id=request_id,
-                lora_request=lora_request,
-            )
-            multi_modal_data = None
-        elif isinstance(inputs, dict):
-            if "prompt_token_ids" in inputs:
-                prompt = None
-                prompt_token_ids = inputs["prompt_token_ids"]
-            else:
-                # NOTE: This extra assignment is required to pass mypy
-                prompt = parsed_prompt = inputs["prompt"]
-                prompt_token_ids = await self._tokenize_prompt_async(
-                    parsed_prompt,
-                    request_id=request_id,
-                    lora_request=lora_request,
-                )
-
-            multi_modal_data = inputs.get("multi_modal_data")
-        else:
-            assert_never(inputs)
-
-        return prompt, prompt_token_ids, multi_modal_data
-
-    async def _process_encoder_decoder_prompt_async(
-        self,
-        inputs: PromptInputs,
-        request_id: str,
-    ) -> EncoderDecoderLLMInputs:
-        """Async version of :meth:`_process_encoder_decoder_prompt`."""
-        encoder_comps: PromptComponents
-        decoder_comps: DecoderPromptComponents
-
-        if is_explicit_encoder_decoder_prompt(inputs):
-            encoder_task = self._extract_prompt_components_async(
-                inputs["encoder_prompt"],
-                request_id=request_id,
-            )
-
-            if (decoder_input := inputs["decoder_prompt"]) is None:
-                encoder_comps = await encoder_task
-                decoder_comps = None, None, None
-            else:
-                decoder_task = self._extract_prompt_components_async(
-                    decoder_input,
-                    request_id=request_id,
-                )
-
-                encoder_comps, decoder_comps = await asyncio.gather(
-                    encoder_task, decoder_task)
-        else:
-            encoder_comps = await self._extract_prompt_components_async(
-                inputs,
-                request_id=request_id,
-            )
-
-            decoder_comps = None, None, None
-
-        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
-
-    async def _process_decoder_only_prompt_async(
-        self,
-        inputs: SingletonPromptInputs,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> LLMInputs:
-        """Async version of :meth:`_process_decoder_only_prompt`."""
-        prompt_comps = await self._extract_prompt_components_async(
-            inputs,
-            request_id=request_id,
-            lora_request=lora_request,
-        )
-
-        return self._build_decoder_only_llm_inputs(
-            prompt_comps,
-            prompt_adapter_request=prompt_adapter_request,
-        )
-
-    async def process_model_inputs_async(
-        self,
-        inputs: PromptInputs,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
-        """Async version of :meth:`process_model_inputs`."""
-        if self.is_encoder_decoder_model():
-            # Encoder-decoder model requires special mapping of
-            # input prompts to encoder & decoder
-            model_inputs = await self._process_encoder_decoder_prompt_async(
-                inputs,
-                request_id=request_id,
-            )
-        else:
-            if is_explicit_encoder_decoder_prompt(inputs):
-                raise ValueError("Cannot pass encoder-decoder prompt "
-                                 "to decoder-only models")
-
-            # Decoder-only operation
-            model_inputs = await self._process_decoder_only_prompt_async(
-                inputs,
-                request_id=request_id,
-                lora_request=lora_request,
-                prompt_adapter_request=prompt_adapter_request,
-            )
-
-        return self.input_processor(model_inputs)
-
     async def add_request_async(
         self,
         request_id: str,
@@ -553,12 +415,13 @@ async def add_request_async(
         if arrival_time is None:
             arrival_time = time.time()

-        processed_inputs = await self.process_model_inputs_async(
+        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
             inputs,
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )
+        processed_inputs = self.input_processor(preprocessed_inputs)

         self._add_processed_request(
             request_id=request_id,
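
One detail worth noting in the removed _process_encoder_decoder_prompt_async above: when an explicit encoder/decoder prompt supplies both halves, the two tokenization coroutines are awaited concurrently with asyncio.gather rather than sequentially. A minimal standalone sketch of that pattern (the fake tokenizer and the sleep are invented for the example):

import asyncio
from typing import List, Tuple


async def fake_encode_async(prompt: str) -> List[int]:
    # Stand-in for an async tokenizer call; the sleep only simulates latency.
    await asyncio.sleep(0.01)
    return [ord(ch) for ch in prompt]


async def preprocess_enc_dec(
        encoder_prompt: str,
        decoder_prompt: str) -> Tuple[List[int], List[int]]:
    # Tokenize both prompts concurrently, mirroring the asyncio.gather call
    # in the removed encoder/decoder preprocessing code above.
    encoder_ids, decoder_ids = await asyncio.gather(
        fake_encode_async(encoder_prompt),
        fake_encode_async(decoder_prompt),
    )
    return encoder_ids, decoder_ids


if __name__ == "__main__":
    print(asyncio.run(preprocess_enc_dec("translate this", "target prefix")))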