diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py index 338b208723ba9..b8818af5614cf 100644 --- a/tests/engine/test_skip_tokenizer_init.py +++ b/tests/engine/test_skip_tokenizer_init.py @@ -11,9 +11,10 @@ def test_skip_tokenizer_initialization(model: str): # token ids. llm = LLM(model=model, skip_tokenizer_init=True) sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) - with pytest.raises(ValueError) as err: + + with pytest.raises(ValueError, match="cannot pass text prompts when"): llm.generate("abc", sampling_params) - assert "prompts must be None if" in str(err.value) + outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params) assert len(outputs) > 0 diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index c0f5e2b408da9..b3e772c560863 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -4,22 +4,17 @@ from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union) -from typing_extensions import assert_never - import vllm.envs as envs from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout -from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine, - PromptComponents, SchedulerOutputState) +from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState from vllm.engine.metrics_types import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, - SingletonPromptInputs) -from vllm.inputs.parse import is_explicit_encoder_decoder_prompt +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -403,139 +398,6 @@ async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() - async def _tokenize_prompt_async( - self, - prompt: str, - request_id: str, - lora_request: Optional[LoRARequest], - ) -> List[int]: - """Async version of :meth:`_tokenize_prompt`.""" - tokenizer = self.get_tokenizer_group( - missing_msg="prompts must be None if skip_tokenizer_init is True") - - return await tokenizer.encode_async(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - - async def _extract_prompt_components_async( - self, - inputs: SingletonPromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: - """Async version of :meth:`_extract_prompt_components`.""" - if isinstance(inputs, str): - prompt = inputs - prompt_token_ids = await self._tokenize_prompt_async( - prompt, - request_id=request_id, - lora_request=lora_request, - ) - multi_modal_data = None - elif isinstance(inputs, dict): - if "prompt_token_ids" in inputs: - prompt = None - prompt_token_ids = inputs["prompt_token_ids"] - else: - # NOTE: This extra assignment is required to pass mypy - prompt = parsed_prompt = inputs["prompt"] - prompt_token_ids = await self._tokenize_prompt_async( - parsed_prompt, - request_id=request_id, - lora_request=lora_request, - ) - - multi_modal_data = inputs.get("multi_modal_data") - else: - assert_never(inputs) - - return prompt, prompt_token_ids, multi_modal_data - - async def _process_encoder_decoder_prompt_async( - self, - inputs: PromptInputs, - request_id: str, - ) -> EncoderDecoderLLMInputs: - """Async version of :meth:`_process_encoder_decoder_prompt`.""" - encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents - - if is_explicit_encoder_decoder_prompt(inputs): - encoder_task = self._extract_prompt_components_async( - inputs["encoder_prompt"], - request_id=request_id, - ) - - if (decoder_input := inputs["decoder_prompt"]) is None: - encoder_comps = await encoder_task - decoder_comps = None, None, None - else: - decoder_task = self._extract_prompt_components_async( - decoder_input, - request_id=request_id, - ) - - encoder_comps, decoder_comps = await asyncio.gather( - encoder_task, decoder_task) - else: - encoder_comps = await self._extract_prompt_components_async( - inputs, - request_id=request_id, - ) - - decoder_comps = None, None, None - - return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) - - async def _process_decoder_only_prompt_async( - self, - inputs: SingletonPromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: - """Async version of :meth:`_process_decoder_only_prompt`.""" - prompt_comps = await self._extract_prompt_components_async( - inputs, - request_id=request_id, - lora_request=lora_request, - ) - - return self._build_decoder_only_llm_inputs( - prompt_comps, - prompt_adapter_request=prompt_adapter_request, - ) - - async def process_model_inputs_async( - self, - inputs: PromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: - """Async version of :meth:`process_model_inputs`.""" - if self.is_encoder_decoder_model(): - # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder - model_inputs = await self._process_encoder_decoder_prompt_async( - inputs, - request_id=request_id, - ) - else: - if is_explicit_encoder_decoder_prompt(inputs): - raise ValueError("Cannot pass encoder-decoder prompt " - "to decoder-only models") - - # Decoder-only operation - model_inputs = await self._process_decoder_only_prompt_async( - inputs, - request_id=request_id, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) - - return self.input_processor(model_inputs) - async def process_model_params_async( self, request_id: str, @@ -591,7 +453,7 @@ async def add_request_async( if arrival_time is None: arrival_time = time.time() - processed_inputs = await self.process_model_inputs_async( + preprocessed_inputs = await self.input_preprocessor.preprocess_async( inputs, request_id=request_id, lora_request=lora_request, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 32414f7b39972..1745dc5c09803 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,10 +6,10 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence -from typing import Set, Tuple, Type, Union +from typing import Set, Type, Union import torch -from typing_extensions import TypeVar, assert_never +from typing_extensions import TypeVar import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, @@ -28,13 +28,11 @@ from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptInputs, - SingletonPromptInputs) -from vllm.inputs.parse import is_explicit_encoder_decoder_prompt + InputRegistry, LLMInputs, PromptInputs) +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MultiModalDataDict from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams @@ -75,11 +73,6 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) -PromptComponents = Tuple[Optional[str], List[int], - Optional[MultiModalDataDict]] -DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], - Optional[MultiModalDataDict]] - @dataclass class SchedulerOutputState: @@ -313,6 +306,9 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.generation_config_fields = _load_generation_config_dict( model_config) + self.input_preprocessor = InputPreprocessor(model_config, + self.tokenizer) + self.input_registry = input_registry self.input_processor = input_registry.create_input_processor( model_config) @@ -580,13 +576,12 @@ def __del__(self): def get_tokenizer_group( self, group_type: Type[_G] = BaseTokenizerGroup, - *, - missing_msg: str = MISSING_TOKENIZER_GROUP_MSG, ) -> _G: tokenizer_group = self.tokenizer if tokenizer_group is None: - raise ValueError(missing_msg) + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") if not isinstance(tokenizer_group, group_type): raise TypeError("Invalid type of tokenizer group. " f"Expected type: {group_type}, but " @@ -618,52 +613,6 @@ def _verify_args(self) -> None: self.prompt_adapter_config.verify_with_model_config( self.model_config) - def _get_bos_token_id(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: - if self.tokenizer is None: - logger.warning("Using None for BOS token id because tokenizer " - "is not initialized") - return None - - return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id - - def _get_eos_token_id(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: - if self.tokenizer is None: - logger.warning("Using None for EOS token id because tokenizer " - "is not initialized") - return None - - return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id - - def _get_decoder_start_token_id(self) -> Optional[int]: - ''' - Obtain the decoder start token id employed by an encoder/decoder - model. Returns None for non-encoder/decoder models or if the - model config is unavailable. - ''' - - if not self.is_encoder_decoder_model(): - logger.warning("Using None for decoder start token id because " - "this is not an encoder/decoder model.") - return None - - if (self.model_config is None or self.model_config.hf_config is None): - logger.warning("Using None for decoder start token id because " - "model config is not available.") - return None - - dec_start_token_id = getattr(self.model_config.hf_config, - 'decoder_start_token_id', None) - if dec_start_token_id is None: - logger.warning("Falling back on for decoder start token id " - "because decoder start token id is not available.") - dec_start_token_id = self._get_bos_token_id() - - return dec_start_token_id - def _add_processed_request( self, request_id: str, @@ -678,7 +627,7 @@ def _add_processed_request( # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - eos_token_id = self._get_eos_token_id(lora_request) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, lora_request, prompt_adapter_request) @@ -727,334 +676,6 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() - _LLMInputComponentsType = Tuple[str, List[int]] - - def _prepare_decoder_input_ids_for_generation( - self, - decoder_input_ids: Optional[List[int]], - ) -> List[int]: - """ - Prepares `decoder_input_ids` for generation with encoder-decoder models. - - Based on - - https://github.com/huggingface/transformers/blob/ - 4037a2b5b1278736e566aec12e169100275545ea/ - src/transformers/generation/utils.py - - specifically GenerationMixin._prepare_decoder_input_ids_for_generation() - - Arguments: - - * decoder_input_ids: input token ids to preprocess - - Returns: - - * Processed token list - """ - - decoder_start_token_id = self._get_decoder_start_token_id() - assert decoder_start_token_id is not None - - if decoder_input_ids is None: - # no decoder prompt input -> - # use decoder_start_token_id as decoder_input_ids - decoder_input_ids = self._get_default_enc_dec_decoder_prompt() - - if (len(decoder_input_ids) == 0 - or decoder_input_ids[0] != decoder_start_token_id): - decoder_input_ids = [decoder_start_token_id] + decoder_input_ids - - return decoder_input_ids - - def _tokenize_prompt( - self, - prompt: str, - request_id: str, - lora_request: Optional[LoRARequest], - ) -> List[int]: - ''' - Wrapper around application of the model's tokenizer. - - Arguments: - - * prompt - * request_id - * lora_request - - Returns: - - * prompt token ids - ''' - - tokenizer = self.get_tokenizer_group( - missing_msg="prompts must be None if skip_tokenizer_init is True") - - return tokenizer.encode(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - - def _extract_prompt_components( - self, - inputs: SingletonPromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: - ''' - Extract the components of any single encoder or decoder input prompt. - - Arguments: - - * request_id - * inputs: single encoder or decoder input prompt - * lora_request: this is only valid for decoder prompts - - Returns: - - * prompt - * prompt_token_ids - * multi_modal_data - ''' - - if isinstance(inputs, str): - prompt = inputs - prompt_token_ids = self._tokenize_prompt( - prompt, - request_id=request_id, - lora_request=lora_request, - ) - multi_modal_data = None - elif isinstance(inputs, dict): - if "prompt_token_ids" in inputs: - prompt = None - prompt_token_ids = inputs["prompt_token_ids"] - else: - # NOTE: This extra assignment is required to pass mypy - prompt = parsed_prompt = inputs["prompt"] - prompt_token_ids = self._tokenize_prompt( - parsed_prompt, - request_id=request_id, - lora_request=lora_request, - ) - - multi_modal_data = inputs.get("multi_modal_data") - else: - assert_never(inputs) - - return prompt, prompt_token_ids, multi_modal_data - - def _apply_prompt_adapter( - self, - prompt_token_ids: List[int], - prompt_adapter_request: Optional[PromptAdapterRequest], - ) -> List[int]: - if prompt_adapter_request: - prompt_token_ids = ( - [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens - + prompt_token_ids) - - return prompt_token_ids - - def _get_default_enc_dec_decoder_prompt(self) -> List[int]: - ''' - Specifically for encoder/decoder models: - generate a default decoder prompt for when - the user specifies only the encoder prompt. - - Encoder/decoder models utilize the decoder - prompt in different ways; as new models are - added, it is intended that this function - will be extended to produce differing - default decoder prompts, depending on the - model variety. - - Absent a special case, the default behavior - of this method is to mirror the behavior of - the HuggingFace (HF) GenerationMixin for a None - decoder prompt, which is to employ a logit processor - setting to force the first decoded token to be . - Here, this behavior is approximated by having the - "default" decoder prompt be . - - However, it is possible that in the future - other models may have different or more - complex logic for the default decoder prompt. - This motivates having a special helper method - for default decoder prompts. - - Returns: - - * prompt_token_ids - ''' - - bos_token_id = self._get_bos_token_id() - assert bos_token_id is not None - return [bos_token_id] - - def _build_enc_dec_llm_inputs( - self, - encoder_comps: PromptComponents, - decoder_comps: DecoderPromptComponents, - ) -> EncoderDecoderLLMInputs: - encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps - decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps - - if encoder_mm_data is not None or decoder_mm_data is not None: - raise ValueError("Multi-modal encoder-decoder models are " - "not supported yet") - - decoder_prompt_ids = ( - self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) - - return EncoderDecoderLLMInputs( - prompt_token_ids=decoder_prompt_ids, - prompt=decoder_prompt, - encoder_prompt_token_ids=encoder_prompt_ids, - encoder_prompt=encoder_prompt, - ) - - def _process_encoder_decoder_prompt( - self, - inputs: PromptInputs, - request_id: str, - ) -> EncoderDecoderLLMInputs: - ''' - For encoder/decoder models only: - Process an input prompt into an - :class:`EncoderDecoderLLMInputs` instance. - - There are two types of input prompts: - singleton prompts which carry only the - encoder prompt, and explicit encoder/decoder - prompts which carry both the encoder and the - decoder prompts as member variables. - - This function handles the following scenarios: - * Singleton encoder prompt: extract encoder prompt - token ids & infer default decoder prompt token ids - * Explicit encoder/decoder prompt: extract encoder - and decoder prompt token ids - - Note that for Explicit encoder/decoder prompts, - each sub-prompt (encoder or decoder prompt) can - have any possible singleton type; thus this - method relies on helper functions to obtain - token ids for the sub-prompts. - - Arguments: - - * inputs: an input prompt - * request_id - - Returns: - - * :class:`EncoderDecoderLLMInputs` instance - ''' - - encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents - - if is_explicit_encoder_decoder_prompt(inputs): - encoder_comps = self._extract_prompt_components( - inputs["encoder_prompt"], - request_id=request_id, - ) - - if (decoder_input := inputs["decoder_prompt"]) is None: - decoder_comps = None, None, None - else: - decoder_comps = self._extract_prompt_components( - decoder_input, - request_id=request_id, - ) - else: - encoder_comps = self._extract_prompt_components( - inputs, - request_id=request_id, - ) - - decoder_comps = None, None, None - - return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) - - def _build_decoder_only_llm_inputs( - self, - prompt_comps: PromptComponents, - prompt_adapter_request: Optional[PromptAdapterRequest], - ) -> LLMInputs: - prompt, prompt_token_ids, multi_modal_data = prompt_comps - - prompt_token_ids = self._apply_prompt_adapter( - prompt_token_ids, prompt_adapter_request=prompt_adapter_request) - - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=prompt, - multi_modal_data=multi_modal_data) - - def _process_decoder_only_prompt( - self, - inputs: SingletonPromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: - ''' - For decoder-only models: - Process an input prompt into an :class:`LLMInputs` instance. - - Arguments: - - * inputs: input prompt - * request_id - * lora_request - * prompt_adapter_request - - Returns: - - * :class:`LLMInputs` instance - ''' - - prompt_comps = self._extract_prompt_components( - inputs, - request_id=request_id, - lora_request=lora_request, - ) - - return self._build_decoder_only_llm_inputs( - prompt_comps, - prompt_adapter_request=prompt_adapter_request, - ) - - def process_model_inputs( - self, - inputs: PromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: - - if self.is_encoder_decoder_model(): - # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder - model_inputs = self._process_encoder_decoder_prompt( - inputs, - request_id=request_id, - ) - else: - if is_explicit_encoder_decoder_prompt(inputs): - raise ValueError("Cannot pass encoder-decoder prompt " - "to decoder-only models") - - # Decoder-only operation - model_inputs = self._process_decoder_only_prompt( - inputs, - request_id=request_id, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) - - return self.input_processor(model_inputs) - def process_model_params( self, request_id: str, @@ -1151,12 +772,13 @@ def add_request( if arrival_time is None: arrival_time = time.time() - processed_inputs = self.process_model_inputs( + preprocessed_inputs = self.input_preprocessor.preprocess( inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) + processed_inputs = self.input_processor(preprocessed_inputs) processed_params = self.process_model_params(request_id=request_id, params=params, @@ -2087,7 +1709,7 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: metrics.model_execute_time) def is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model + return self.input_preprocessor.is_encoder_decoder_model() def is_embedding_model(self): return self.model_config.is_embedding_model diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index b5e8ef7860598..ac9d355c64c80 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -5,7 +5,8 @@ from vllm.utils import is_list_of from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs) + LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + TokensPrompt) class ParsedText(TypedDict): @@ -60,8 +61,38 @@ def parse_and_batch_prompt( for elem in prompt ] - raise ValueError("prompt must be a string, array of strings, " - "array of tokens, or array of token arrays") + raise TypeError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + +class ParsedStrPrompt(TypedDict): + type: Literal["str"] + content: str + + +class ParsedTextPrompt(TypedDict): + type: Literal["text"] + content: TextPrompt + + +class ParsedTokensPrompt(TypedDict): + type: Literal["tokens"] + content: TokensPrompt + + +def parse_singleton_prompt( + inputs: SingletonPromptInputs, +) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: + if isinstance(inputs, str): + return ParsedStrPrompt(type="str", content=inputs) + elif isinstance(inputs, dict): + if "prompt_token_ids" in inputs: + return ParsedTokensPrompt(type="tokens", + content=inputs) # type: ignore + elif "prompt" in inputs: + return ParsedTextPrompt(type="text", content=inputs) + + raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") def is_explicit_encoder_decoder_prompt( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py new file mode 100644 index 0000000000000..be2aa5f8cb7d0 --- /dev/null +++ b/vllm/inputs/preprocess.py @@ -0,0 +1,536 @@ +import asyncio +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +from typing_extensions import assert_never + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup + +from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, + SingletonPromptInputs) +from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt + +if TYPE_CHECKING: + from vllm.multimodal import MultiModalDataDict + +logger = init_logger(__name__) + +PromptComponents = Tuple[Optional[str], List[int], + Optional["MultiModalDataDict"]] +DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], + Optional["MultiModalDataDict"]] + + +class InputPreprocessor: + + def __init__( + self, + model_config: ModelConfig, + tokenizer: Optional[BaseTokenizerGroup], + ) -> None: + super().__init__() + + self.model_config = model_config + self.tokenizer = tokenizer + + def get_tokenizer_group(self) -> BaseTokenizerGroup: + if self.tokenizer is None: + raise ValueError("You cannot pass text prompts when " + "`skip_tokenizer_init` is True") + + return self.tokenizer + + def get_bos_token_id(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for BOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id + + def get_eos_token_id(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for EOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + + def get_decoder_start_token_id(self) -> Optional[int]: + ''' + Obtain the decoder start token id employed by an encoder/decoder + model. Returns None for non-encoder/decoder models or if the + model config is unavailable. + ''' + + if not self.is_encoder_decoder_model(): + logger.warning("Using None for decoder start token id because " + "this is not an encoder/decoder model.") + return None + + if (self.model_config is None or self.model_config.hf_config is None): + logger.warning("Using None for decoder start token id because " + "model config is not available.") + return None + + dec_start_token_id = getattr(self.model_config.hf_config, + 'decoder_start_token_id', None) + if dec_start_token_id is None: + logger.warning("Falling back on for decoder start token id " + "because decoder start token id is not available.") + dec_start_token_id = self.get_bos_token_id() + + return dec_start_token_id + + def _get_default_enc_dec_decoder_prompt(self) -> List[int]: + ''' + Specifically for encoder/decoder models: + generate a default decoder prompt for when + the user specifies only the encoder prompt. + + Encoder/decoder models utilize the decoder + prompt in different ways; as new models are + added, it is intended that this function + will be extended to produce differing + default decoder prompts, depending on the + model variety. + + Absent a special case, the default behavior + of this method is to mirror the behavior of + the HuggingFace (HF) GenerationMixin for a None + decoder prompt, which is to employ a logit processor + setting to force the first decoded token to be . + Here, this behavior is approximated by having the + "default" decoder prompt be . + + However, it is possible that in the future + other models may have different or more + complex logic for the default decoder prompt. + This motivates having a special helper method + for default decoder prompts. + + Returns: + + * prompt_token_ids + ''' + + bos_token_id = self.get_bos_token_id() + assert bos_token_id is not None + return [bos_token_id] + + def _prepare_decoder_input_ids_for_generation( + self, + decoder_input_ids: Optional[List[int]], + ) -> List[int]: + """ + Prepares `decoder_input_ids` for generation with encoder-decoder models. + + Based on + + https://github.com/huggingface/transformers/blob/ + 4037a2b5b1278736e566aec12e169100275545ea/ + src/transformers/generation/utils.py + + specifically GenerationMixin._prepare_decoder_input_ids_for_generation() + + Arguments: + + * decoder_input_ids: input token ids to preprocess + + Returns: + + * Processed token list + """ + + decoder_start_token_id = self.get_decoder_start_token_id() + assert decoder_start_token_id is not None + + if decoder_input_ids is None: + # no decoder prompt input -> + # use decoder_start_token_id as decoder_input_ids + decoder_input_ids = self._get_default_enc_dec_decoder_prompt() + + if (len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id): + decoder_input_ids = [decoder_start_token_id] + decoder_input_ids + + return decoder_input_ids + + def _apply_prompt_adapter( + self, + prompt_token_ids: List[int], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> List[int]: + if prompt_adapter_request: + prompt_token_ids = ( + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + + prompt_token_ids) + + return prompt_token_ids + + def _tokenize_prompt( + self, + prompt: str, + request_id: str, + lora_request: Optional[LoRARequest], + ) -> List[int]: + """ + Apply the model's tokenizer to a text prompt, returning the + corresponding token IDs. + """ + tokenizer = self.get_tokenizer_group() + + return tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + + async def _tokenize_prompt_async( + self, + prompt: str, + request_id: str, + lora_request: Optional[LoRARequest], + ) -> List[int]: + """Async version of :meth:`_tokenize_prompt`.""" + tokenizer = self.get_tokenizer_group() + + return await tokenizer.encode_async(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + + def _extract_prompt_components( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + ) -> PromptComponents: + ''' + Extract the components of any single encoder or decoder input prompt. + + Arguments: + + * request_id + * inputs: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + + Returns: + + * prompt + * prompt_token_ids + * multi_modal_data + ''' + + parsed = parse_singleton_prompt(inputs) + + if parsed["type"] == "str": + prompt = parsed["content"] + prompt_token_ids = self._tokenize_prompt( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = None + elif parsed["type"] == "tokens": + prompt = None + prompt_token_ids = parsed["content"]["prompt_token_ids"] + multi_modal_data = parsed["content"].get("multi_modal_data") + elif parsed["type"] == "text": + prompt = parsed["content"]["prompt"] + prompt_token_ids = self._tokenize_prompt( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = parsed["content"].get("multi_modal_data") + else: + assert_never(parsed) + + return prompt, prompt_token_ids, multi_modal_data + + async def _extract_prompt_components_async( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + ) -> PromptComponents: + """Async version of :meth:`_extract_prompt_components`.""" + parsed = parse_singleton_prompt(inputs) + + if parsed["type"] == "str": + prompt = parsed["content"] + prompt_token_ids = await self._tokenize_prompt_async( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = None + elif parsed["type"] == "tokens": + prompt = None + prompt_token_ids = parsed["content"]["prompt_token_ids"] + multi_modal_data = parsed["content"].get("multi_modal_data") + elif parsed["type"] == "text": + prompt = parsed["content"]["prompt"] + prompt_token_ids = await self._tokenize_prompt_async( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = parsed["content"].get("multi_modal_data") + else: + assert_never(parsed) + + return prompt, prompt_token_ids, multi_modal_data + + def _build_enc_dec_llm_inputs( + self, + encoder_comps: PromptComponents, + decoder_comps: DecoderPromptComponents, + ) -> EncoderDecoderLLMInputs: + encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps + decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps + + if encoder_mm_data is not None or decoder_mm_data is not None: + raise ValueError("Multi-modal encoder-decoder models are " + "not supported yet") + + decoder_prompt_ids = ( + self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) + + return EncoderDecoderLLMInputs( + prompt_token_ids=decoder_prompt_ids, + prompt=decoder_prompt, + encoder_prompt_token_ids=encoder_prompt_ids, + encoder_prompt=encoder_prompt, + ) + + def _process_encoder_decoder_prompt( + self, + inputs: PromptInputs, + request_id: str, + ) -> EncoderDecoderLLMInputs: + ''' + For encoder/decoder models only: + Process an input prompt into an + :class:`EncoderDecoderLLMInputs` instance. + + There are two types of input prompts: + singleton prompts which carry only the + encoder prompt, and explicit encoder/decoder + prompts which carry both the encoder and the + decoder prompts as member variables. + + This function handles the following scenarios: + * Singleton encoder prompt: extract encoder prompt + token ids & infer default decoder prompt token ids + * Explicit encoder/decoder prompt: extract encoder + and decoder prompt token ids + + Note that for Explicit encoder/decoder prompts, + each sub-prompt (encoder or decoder prompt) can + have any possible singleton type; thus this + method relies on helper functions to obtain + token ids for the sub-prompts. + + Arguments: + + * inputs: an input prompt + * request_id + + Returns: + + * :class:`EncoderDecoderLLMInputs` instance + ''' + + encoder_comps: PromptComponents + decoder_comps: DecoderPromptComponents + + if is_explicit_encoder_decoder_prompt(inputs): + encoder_comps = self._extract_prompt_components( + inputs["encoder_prompt"], + request_id=request_id, + ) + + if (decoder_input := inputs["decoder_prompt"]) is None: + decoder_comps = None, None, None + else: + decoder_comps = self._extract_prompt_components( + decoder_input, + request_id=request_id, + ) + else: + encoder_comps = self._extract_prompt_components( + inputs, + request_id=request_id, + ) + + decoder_comps = None, None, None + + return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + + async def _process_encoder_decoder_prompt_async( + self, + inputs: PromptInputs, + request_id: str, + ) -> EncoderDecoderLLMInputs: + """Async version of :meth:`_process_encoder_decoder_prompt`.""" + encoder_comps: PromptComponents + decoder_comps: DecoderPromptComponents + + if is_explicit_encoder_decoder_prompt(inputs): + encoder_task = self._extract_prompt_components_async( + inputs["encoder_prompt"], + request_id=request_id, + ) + + if (decoder_input := inputs["decoder_prompt"]) is None: + encoder_comps = await encoder_task + decoder_comps = None, None, None + else: + decoder_task = self._extract_prompt_components_async( + decoder_input, + request_id=request_id, + ) + + encoder_comps, decoder_comps = await asyncio.gather( + encoder_task, decoder_task) + else: + encoder_comps = await self._extract_prompt_components_async( + inputs, + request_id=request_id, + ) + + decoder_comps = None, None, None + + return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + + def _build_decoder_only_llm_inputs( + self, + prompt_comps: PromptComponents, + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> LLMInputs: + prompt, prompt_token_ids, multi_modal_data = prompt_comps + + prompt_token_ids = self._apply_prompt_adapter( + prompt_token_ids, prompt_adapter_request=prompt_adapter_request) + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=prompt, + multi_modal_data=multi_modal_data) + + def _process_decoder_only_prompt( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> LLMInputs: + ''' + For decoder-only models: + Process an input prompt into an :class:`LLMInputs` instance. + + Arguments: + + * inputs: input prompt + * request_id + * lora_request + * prompt_adapter_request + + Returns: + + * :class:`LLMInputs` instance + ''' + + prompt_comps = self._extract_prompt_components( + inputs, + request_id=request_id, + lora_request=lora_request, + ) + + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) + + async def _process_decoder_only_prompt_async( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> LLMInputs: + """Async version of :meth:`_process_decoder_only_prompt`.""" + prompt_comps = await self._extract_prompt_components_async( + inputs, + request_id=request_id, + lora_request=lora_request, + ) + + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) + + def preprocess( + self, + inputs: PromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + """Preprocess the input prompt.""" + if self.is_encoder_decoder_model(): + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + return self._process_encoder_decoder_prompt( + inputs, + request_id=request_id, + ) + + if is_explicit_encoder_decoder_prompt(inputs): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + + # Decoder-only operation + return self._process_decoder_only_prompt( + inputs, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + async def preprocess_async( + self, + inputs: PromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + """Async version of :meth:`preprocess`.""" + if self.is_encoder_decoder_model(): + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + return await self._process_encoder_decoder_prompt_async( + inputs, + request_id=request_id, + ) + + if is_explicit_encoder_decoder_prompt(inputs): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + + # Decoder-only operation + return await self._process_decoder_only_prompt_async( + inputs, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + def is_encoder_decoder_model(self): + return self.model_config.is_encoder_decoder_model