From cef68948947a69782062203af0caa8876cbab684 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=B4=E6=99=AF?= Date: Thu, 2 Jan 2025 18:54:51 +0800 Subject: [PATCH 01/96] (vllm) add input embedding Signed-off-by: Andrew Sansom --- vllm/engine/async_llm_engine.py | 7 ++++ vllm/engine/llm_engine.py | 5 +++ vllm/entrypoints/llm.py | 19 +++++++++ vllm/inputs/data.py | 14 ++++++- vllm/inputs/parse.py | 3 ++ vllm/inputs/preprocess.py | 4 ++ vllm/model_executor/models/qwen2.py | 7 ++++ vllm/sequence.py | 14 ++++++- vllm/worker/model_runner.py | 60 +++++++++++++++++++++++++++++ 9 files changed, 131 insertions(+), 2 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7f9f85e1f93f..554590bb3a15 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -489,6 +489,13 @@ async def add_request_async( if arrival_time is None: arrival_time = time.time() + if isinstance(prompt, dict) and prompt.get("prompt_embeds", + None) is not None: + if not prompt.get("prompt_token_ids", None): + prompt["prompt_token_ids"] = [ + 0 + ] * prompt["prompt_embeds"].shape[0] + if self.tokenizer is not None: tokenizer = await self.get_tokenizer_async(lora_request) self._validate_token_prompt(prompt, tokenizer=tokenizer) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f842581bf551..c72e1bca3eba 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -772,6 +772,11 @@ def add_request( if arrival_time is None: arrival_time = time.time() + if isinstance(prompt, dict) and prompt.get("prompt_embeds", + None) is not None: + if not prompt.get("prompt_token_ids", None): + prompt["prompt_token_ids"] = [0] * len(prompt["prompt_embeds"]) + if self.tokenizer is not None: self._validate_token_prompt( prompt, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f39b011c9301..bd8569794927 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -9,7 +9,12 @@ import cloudpickle import torch.nn as nn from tqdm import tqdm +<<<<<<< HEAD from typing_extensions import TypeVar, deprecated +======= +from typing_extensions import deprecated +import torch +>>>>>>> 0d69ec2f ((vllm) add input embedding) from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) @@ -377,7 +382,12 @@ def generate( Optional[Union[str, list[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, +<<<<<<< HEAD prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, +======= + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, +>>>>>>> 0d69ec2f ((vllm) add input embedding) use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -442,6 +452,9 @@ def generate( parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], prompts) + if prompt_embeds is not None: + parsed_prompts.prompt_embeds = prompt_embeds + if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: raise ValueError( @@ -1233,8 +1246,14 @@ def wake_up(self, tags: Optional[list[str]] = None): # LEGACY def _convert_v1_inputs( self, +<<<<<<< HEAD prompts: Optional[Union[str, list[str]]], prompt_token_ids: Optional[Union[list[int], list[list[int]]]], +======= + prompts: Optional[Union[str, List[str]]], + prompt_token_ids: Optional[Union[List[int], List[List[int]]]], + 
prompt_embeds: Optional[torch.Tensor] = None, +>>>>>>> 0d69ec2f ((vllm) add input embedding) ): # skip_tokenizer_init is now checked in engine diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 138a8f61107b..8be04b766a15 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -20,6 +20,9 @@ class TextPrompt(TypedDict): prompt: str """The input text to be tokenized before passing to the model.""" + prompt_embeds: NotRequired[torch.Tensor] + """The embeddings of the prompt, if available.""" + multi_modal_data: NotRequired["MultiModalDataDict"] """ Optional multi-modal data to pass to the model, @@ -41,6 +44,9 @@ class TokensPrompt(TypedDict): prompt_token_ids: list[int] """A list of token IDs to pass to the model.""" + prompt_embeds: NotRequired[torch.Tensor] + """The embeddings of the prompt, if available.""" + token_type_ids: NotRequired[list[int]] """A list of token type IDs to pass to the cross encoder model.""" @@ -139,6 +145,9 @@ class TokenInputs(TypedDict): prompt_token_ids: list[int] """The token IDs of the prompt.""" + prompt_embeds: NotRequired[torch.Tensor] + """The embeddings of the prompt, if available.""" + token_type_ids: NotRequired[list[int]] """The token type IDs of the prompt.""" @@ -182,6 +191,7 @@ def token_inputs( prompt_token_ids: list[int], token_type_ids: Optional[list[int]] = None, prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_inputs: Optional["MultiModalKwargs"] = None, multi_modal_hashes: Optional[list[str]] = None, @@ -195,6 +205,8 @@ def token_inputs( inputs["prompt"] = prompt if token_type_ids is not None: inputs["token_type_ids"] = token_type_ids + if prompt_embeds is not None: + inputs["prompt_embeds"] = prompt_embeds if multi_modal_data is not None: inputs["multi_modal_data"] = multi_modal_data if multi_modal_inputs is not None: @@ -277,7 +289,7 @@ def prompt_embeds(self) -> Optional[torch.Tensor]: inputs = self.inputs if inputs["type"] == "token" or inputs["type"] == "multimodal": - return None + return inputs.get("prompt_embeds") assert_never(inputs) # type: ignore[arg-type] diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 28e207de1fd3..c426f36d9041 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -96,6 +96,9 @@ def parse_singleton_prompt( elif "prompt" in prompt: return ParsedTextPrompt(type="text", content=prompt) + elif "prompt_embeds" in prompt: + return ParsedTokensPrompt(type="tokens", content=prompt) + raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 669fb96e6653..2f4dd9375e89 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -354,6 +354,7 @@ def _prompt_to_llm_inputs( return token_inputs( prompt_token_ids=prompt_token_ids, + prompt_embeds=tokens_content.get('prompt_embeds'), token_type_ids=token_type_ids, multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, @@ -383,6 +384,7 @@ def _prompt_to_llm_inputs( return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, + prompt_embeds=text_content.get('prompt_embeds'), multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) @@ -428,6 +430,7 @@ async def _prompt_to_llm_inputs_async( return token_inputs( prompt_token_ids=prompt_token_ids, + prompt_embeds=tokens_content.get('prompt_embeds'), multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) @@ -456,6 +459,7 
@@ async def _prompt_to_llm_inputs_async( return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, + prompt_embeds=text_content.get('prompt_embeds'), multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index c4d02e5ddeb1..514eafac973c 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -437,6 +437,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, + True, quant_config=quant_config, prefix=maybe_prefix( prefix, "lm_head")) @@ -459,8 +460,14 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: +<<<<<<< HEAD hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) +======= + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds, self.lm_head.bias) +>>>>>>> 0d69ec2f ((vllm) add input embedding) return hidden_states def compute_logits( diff --git a/vllm/sequence.py b/vllm/sequence.py index 61867b025315..e90cd7ac6fe2 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -166,6 +166,8 @@ class SequenceData(msgspec.Struct, _output_token_ids: array = msgspec.field( default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) + _prompt_embeds: Optional[torch.Tensor] = None + ### The below fields should not be passed as an argument ### _cumulative_logprob: float = 0.0 _prompt_token_ids_tuple: tuple[int, @@ -262,7 +264,7 @@ def prompt_token_ids_array(self) -> array: @property def output_token_ids(self) -> tuple[int, ...]: return tuple(self._output_token_ids) - + @output_token_ids.setter def output_token_ids(self, new_output_token_ids: GenericSequence[int]) -> None: @@ -270,6 +272,14 @@ def output_token_ids(self, new_output_token_ids) self._update_cached_all_tokens() + @property + def prompt_embeds(self) -> Optional[torch.Tensor]: + return self._prompt_embeds + + @prompt_embeds.setter + def prompt_embeds(self, prompt_embeds: Optional[torch.Tensor]) -> None: + self._prompt_embeds = prompt_embeds + @property def output_token_ids_array(self) -> array: """Return the prompt token ids in array type. @@ -388,6 +398,7 @@ def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self._prompt_token_ids}, " f"output_token_ids={self.output_token_ids}, " + f"prompt_embeds={getattr(self.prompt_embeds, 'shape', None)}, " f"cumulative_logprob={self.cumulative_logprob}, " f"get_num_computed_tokens={self.get_num_computed_tokens()})") @@ -426,6 +437,7 @@ def __init__( self.prompt_adapter_request = prompt_adapter_request self.data = SequenceData.from_seqs(self.prompt_token_ids) + self.data.prompt_embeds = self.inputs.prompt_embeds self.output_logprobs: SampleLogprobs = [] self.output_text = "" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 86e6d9752013..f302b998393d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -84,6 +84,7 @@ class ModelInputForGPU(ModelRunnerInputBase): additional fields. 
""" input_tokens: Optional[torch.Tensor] = None + inputs_embeds: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None token_types: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None @@ -104,6 +105,7 @@ class ModelInputForGPU(ModelRunnerInputBase): def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, + "inputs_embeds": self.inputs_embeds, "input_positions": self.input_positions, "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, @@ -154,6 +156,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, + "inputs_embeds": self.inputs_embeds, "input_positions": self.input_positions, "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, @@ -193,6 +196,7 @@ class InterDataForSeqGroup: def simple_reinit(self): self.input_tokens[0].clear() # type: ignore + self.inputs_embeds = None # type: ignore self.input_positions[0].clear() # type: ignore self.token_types[0].clear() # type: ignore self.mrope_input_positions = None # type: ignore @@ -220,6 +224,7 @@ def __init__( # Input tokens and positions. input_tokens: Optional[List[List[int]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, input_positions: Optional[List[List[int]]] = None, token_types: Optional[List[List[int]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None, @@ -281,6 +286,11 @@ def __init__( for seq_id in range(len(self.seq_ids)): self.input_tokens[seq_id].clear() + if inputs_embeds is not None: + self.inputs_embeds = inputs_embeds + else: + self.inputs_embeds = None + if input_positions: self.input_positions = input_positions else: @@ -355,6 +365,8 @@ def __init__( else: self.input_tokens = input_tokens or [] + self.inputs_embeds = (inputs_embeds + if inputs_embeds is not None else None) self.input_positions = input_positions or [] self.token_types = token_types or [] self.mrope_input_positions = mrope_input_positions or None @@ -400,6 +412,26 @@ def __post_init__(self): self.lora_index_mapping = [] self.lora_prompt_mapping = [] + def __repr__(self) -> str: + return ( + f"InterDataForSeqGroup(" + f"request_id={self.request_id}, " + f"seq_ids={self.seq_ids}, " + f"is_prompt={self.is_prompt}, " + f"block_tables={self.block_tables}, " + f"computed_block_nums={self.computed_block_nums}, " + f"n_seqs={self.n_seqs}, " + f"input_tokens={self.input_tokens}, " + f"inputs_embeds={getattr(self.inputs_embeds, 'shape', None)}, " + f"input_positions={self.input_positions}, " + f"token_types={self.token_types}, " + f"mrope_input_positions={self.mrope_input_positions}, " + f"seq_lens={self.seq_lens}, " + f"orig_seq_lens={self.orig_seq_lens}, " + f"query_lens={self.query_lens}, " + f"context_lens={self.context_lens}, " + f"multi_modal_kwargs={self.multi_modal_kwargs}") + def gen_inter_data_builder(self, num_seqs: int): return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup( request_id="", @@ -512,12 +544,19 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute tokens. 
tokens = seq_data.get_token_ids()[context_len:seq_len] + if seq_data.prompt_embeds is not None and seq_data.get_output_len( + ) == 0: + prompt_embeds = seq_data.prompt_embeds[context_len:seq_len] + else: + seq_data.prompt_embeds = None + prompt_embeds = None token_types = seq_group_metadata.token_type_ids inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len inter_data.input_tokens[seq_idx].extend(tokens) + inter_data.inputs_embeds = prompt_embeds inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) inter_data.token_types[seq_idx].extend( token_types if token_types else []) @@ -823,12 +862,23 @@ def build(self) -> ModelInputForGPU: """ # Combine and flatten intermediate data. input_tokens = [] + inputs_embeds = [] token_types = [] for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: input_tokens.extend(cur_input_tokens) for cur_token_types in inter_data.token_types: token_types.extend(cur_token_types) + if inter_data.inputs_embeds is not None: + inputs_embeds.append( + inter_data.inputs_embeds.to(self.runner.device)) + + if len(inputs_embeds) == 0: + inputs_embeds = None + elif len(inputs_embeds) == 1: + inputs_embeds = inputs_embeds[0] + else: + inputs_embeds = torch.cat(inputs_embeds, dim=0) if not input_tokens: # This may happen when all prefill requests hit @@ -980,6 +1030,7 @@ def build(self) -> ModelInputForGPU: return self.model_input_cls( input_tokens=input_tokens_tensor, + inputs_embeds=inputs_embeds, input_positions=input_positions_tensor, token_types=token_types_tensor, attn_metadata=attn_metadata, @@ -1769,6 +1820,9 @@ def execute_model( self.vllm_config, virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, + **{ + "inputs_embeds": model_input.inputs_embeds, + } if model_input.inputs_embeds is not None else {}, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, @@ -1883,6 +1937,9 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: # check if the current run is prefill is_prefill_run = prefill_meta is not None + if self.vllm_config.kv_transfer_config is None: + return False + return self.vllm_config.kv_transfer_config.is_kv_consumer and ( not is_profile_run) and is_prefill_run @@ -1908,6 +1965,9 @@ def need_send_kv(self, model_input, kv_caches) -> bool: # check if the current run is prefill is_prefill_run = prefill_meta is not None + if self.vllm_config.kv_transfer_config is None: + return False + return self.vllm_config.kv_transfer_config.is_kv_producer and ( not is_profile_run) and is_prefill_run From c51d8fb518354a0ca93c85b2f2013c06a1503f53 Mon Sep 17 00:00:00 2001 From: Bryce1010 Date: Mon, 6 Jan 2025 19:04:34 +0800 Subject: [PATCH 02/96] improve embedding input Signed-off-by: Andrew Sansom --- vllm/attention/backends/flash_attn.py | 1 - vllm/engine/llm_engine.py | 5 ++--- vllm/entrypoints/llm.py | 14 -------------- vllm/inputs/data.py | 3 +++ vllm/inputs/preprocess.py | 8 ++++---- vllm/model_executor/models/qwen2.py | 6 +----- vllm/sequence.py | 18 +++++++++--------- vllm/worker/model_runner.py | 17 ++++++++--------- 8 files changed, 27 insertions(+), 45 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 27bd292b51f2..3b5e54980195 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -769,7 +769,6 @@ def 
forward( prefill_output = output[:num_prefill_query_tokens] assert query.shape[0] == num_prefill_query_tokens assert decode_query.shape[0] == num_decode_query_tokens - if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if (kv_cache.numel() == 0 or prefill_meta.block_tables is None diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c72e1bca3eba..f476d2b08b0a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -772,10 +772,9 @@ def add_request( if arrival_time is None: arrival_time = time.time() - if isinstance(prompt, dict) and prompt.get("prompt_embeds", - None) is not None: + if isinstance(prompt, dict) and prompt.get("prompt_embeds", None) is not None: if not prompt.get("prompt_token_ids", None): - prompt["prompt_token_ids"] = [0] * len(prompt["prompt_embeds"]) + prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[0] if self.tokenizer is not None: self._validate_token_prompt( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index bd8569794927..ff00e79bd8e3 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -9,12 +9,7 @@ import cloudpickle import torch.nn as nn from tqdm import tqdm -<<<<<<< HEAD from typing_extensions import TypeVar, deprecated -======= -from typing_extensions import deprecated -import torch ->>>>>>> 0d69ec2f ((vllm) add input embedding) from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) @@ -382,12 +377,8 @@ def generate( Optional[Union[str, list[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, -<<<<<<< HEAD - prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, -======= prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, prompt_embeds: Optional[torch.Tensor] = None, ->>>>>>> 0d69ec2f ((vllm) add input embedding) use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -1246,14 +1237,9 @@ def wake_up(self, tags: Optional[list[str]] = None): # LEGACY def _convert_v1_inputs( self, -<<<<<<< HEAD - prompts: Optional[Union[str, list[str]]], - prompt_token_ids: Optional[Union[list[int], list[list[int]]]], -======= prompts: Optional[Union[str, List[str]]], prompt_token_ids: Optional[Union[List[int], List[List[int]]]], prompt_embeds: Optional[torch.Tensor] = None, ->>>>>>> 0d69ec2f ((vllm) add input embedding) ): # skip_tokenizer_init is now checked in engine diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 8be04b766a15..aea9b77b44b7 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -156,6 +156,9 @@ class TokenInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. 
""" + prompt_embeds: NotRequired[torch.Tensor] + """The embeddings of the prompt, if available.""" + multi_modal_data: NotRequired["MultiModalDataDict"] """ Optional multi-modal data to pass to the model, diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 2f4dd9375e89..f77880f97b0d 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -354,7 +354,7 @@ def _prompt_to_llm_inputs( return token_inputs( prompt_token_ids=prompt_token_ids, - prompt_embeds=tokens_content.get('prompt_embeds'), + prompt_embeds=tokens_content.get("prompt_embeds"), token_type_ids=token_type_ids, multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, @@ -384,7 +384,7 @@ def _prompt_to_llm_inputs( return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, - prompt_embeds=text_content.get('prompt_embeds'), + prompt_embeds=text_content.get("prompt_embeds"), multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) @@ -430,7 +430,7 @@ async def _prompt_to_llm_inputs_async( return token_inputs( prompt_token_ids=prompt_token_ids, - prompt_embeds=tokens_content.get('prompt_embeds'), + prompt_embeds=tokens_content.get("prompt_embeds"), multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) @@ -459,7 +459,7 @@ async def _prompt_to_llm_inputs_async( return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, - prompt_embeds=text_content.get('prompt_embeds'), + prompt_embeds=tokens_content.get("prompt_embeds"), multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 514eafac973c..8671fdc28276 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -460,14 +460,10 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: -<<<<<<< HEAD - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) -======= + hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, inputs_embeds, self.lm_head.bias) ->>>>>>> 0d69ec2f ((vllm) add input embedding) return hidden_states def compute_logits( diff --git a/vllm/sequence.py b/vllm/sequence.py index e90cd7ac6fe2..ecbfb899075b 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -272,14 +272,6 @@ def output_token_ids(self, new_output_token_ids) self._update_cached_all_tokens() - @property - def prompt_embeds(self) -> Optional[torch.Tensor]: - return self._prompt_embeds - - @prompt_embeds.setter - def prompt_embeds(self, prompt_embeds: Optional[torch.Tensor]) -> None: - self._prompt_embeds = prompt_embeds - @property def output_token_ids_array(self) -> array: """Return the prompt token ids in array type. 
@@ -289,6 +281,14 @@ def output_token_ids_array(self) -> array: """ assert isinstance(self._output_token_ids, array) return self._output_token_ids + + @property + def prompt_embeds(self) -> Optional[torch.Tensor]: + return self._prompt_embeds + + @prompt_embeds.setter + def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None: + self._prompt_embeds = prompt_embeds @property def mrope_position_delta(self) -> Optional[int]: @@ -397,8 +397,8 @@ def stage(self) -> SequenceStage: def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self._prompt_token_ids}, " + f"prompt_embeds={getattr(self._prompt_embeds, 'shape', None)}, " f"output_token_ids={self.output_token_ids}, " - f"prompt_embeds={getattr(self.prompt_embeds, 'shape', None)}, " f"cumulative_logprob={self.cumulative_logprob}, " f"get_num_computed_tokens={self.get_num_computed_tokens()})") diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f302b998393d..0f44446ca888 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -365,8 +365,9 @@ def __init__( else: self.input_tokens = input_tokens or [] - self.inputs_embeds = (inputs_embeds - if inputs_embeds is not None else None) + self.inputs_embeds = ( + inputs_embeds if inputs_embeds is not None else None + ) self.input_positions = input_positions or [] self.token_types = token_types or [] self.mrope_input_positions = mrope_input_positions or None @@ -544,12 +545,12 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute tokens. tokens = seq_data.get_token_ids()[context_len:seq_len] - if seq_data.prompt_embeds is not None and seq_data.get_output_len( - ) == 0: - prompt_embeds = seq_data.prompt_embeds[context_len:seq_len] + if seq_data.prompt_embeds is not None and seq_data.get_output_len() == 0: + prompt_embeds = seq_data.prompt_embeds[context_len:seq_len] else: - seq_data.prompt_embeds = None + seq_data.prompt_embeds = None # release memory prompt_embeds = None + token_types = seq_group_metadata.token_type_ids inter_data.seq_lens[seq_idx] = seq_len @@ -870,9 +871,7 @@ def build(self) -> ModelInputForGPU: for cur_token_types in inter_data.token_types: token_types.extend(cur_token_types) if inter_data.inputs_embeds is not None: - inputs_embeds.append( - inter_data.inputs_embeds.to(self.runner.device)) - + inputs_embeds.append(inter_data.inputs_embeds.to(self.runner.device)) if len(inputs_embeds) == 0: inputs_embeds = None elif len(inputs_embeds) == 1: From 9564b4023fab034e7ea16fbaefffcaaff69c01a2 Mon Sep 17 00:00:00 2001 From: Bryce1010 Date: Thu, 6 Mar 2025 11:02:23 +0800 Subject: [PATCH 03/96] (vllm) fix import error Signed-off-by: Andrew Sansom --- vllm/entrypoints/llm.py | 39 +++++++++++++++++++++++------ vllm/model_executor/models/qwen2.py | 4 +-- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ff00e79bd8e3..315c2aaef4f8 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -10,7 +10,7 @@ import torch.nn as nn from tqdm import tqdm from typing_extensions import TypeVar, deprecated - +import torch from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) from vllm.config import CompilationConfig @@ -282,6 +282,8 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + priority: 
Optional[list[int]] = None, ) -> list[RequestOutput]: ... @@ -298,6 +300,8 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + priority: Optional[list[int]] = None, ) -> list[RequestOutput]: ... @@ -314,6 +318,8 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + priority: Optional[list[int]] = None, ) -> list[RequestOutput]: ... @@ -331,6 +337,8 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + priority: Optional[list[int]] = None, ) -> list[RequestOutput]: ... @@ -348,6 +356,8 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + priority: Optional[list[int]] = None, ) -> list[RequestOutput]: ... @@ -363,6 +373,8 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + priority: Optional[list[int]] = None, ) -> list[RequestOutput]: ... @@ -377,7 +389,7 @@ def generate( Optional[Union[str, list[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, - prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, @@ -401,10 +413,15 @@ def generate( When it is a single value, it is applied to every prompt. When it is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. + prompt_token_ids: DEPRECATED. Token IDs for the prompts. If provided, + the `prompts` will be ignored. + prompt_embeds: Optional tensor of prompt embeddings to use instead of + text prompts. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. prompt_adapter_request: Prompt Adapter request to use for generation, if any. + guided_options_request: Options for guided decoding, if any. priority: The priority of the requests, if any. Only applicable when priority scheduling policy is enabled. 
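For reference, a sketch of the call pattern this parameter is meant to support, using the dict prompt form defined in vllm/inputs/data.py. The model name and embedding source are illustrative, and the embeddings must be shaped (num_prompt_tokens, hidden_size) for the loaded model:

import torch
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")     # illustrative model
prompt_embeds = torch.load("prompt_embeds.pt")  # precomputed elsewhere

# A prompt dict carrying only prompt_embeds is parsed as a tokens prompt and
# receives placeholder token IDs in add_request.
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(temperature=0.0))
print(outputs[0].outputs[0].text)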
@@ -438,13 +455,13 @@ def generate( parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, list[str]]], prompts), prompt_token_ids=prompt_token_ids, + prompt_embeds=prompt_embeds, ) else: parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], prompts) - - if prompt_embeds is not None: - parsed_prompts.prompt_embeds = prompt_embeds + if prompt_embeds is not None and hasattr(parsed_prompts, "prompt_embeds"): + parsed_prompts.prompt_embeds = prompt_embeds if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: @@ -1237,8 +1254,8 @@ def wake_up(self, tags: Optional[list[str]] = None): # LEGACY def _convert_v1_inputs( self, - prompts: Optional[Union[str, List[str]]], - prompt_token_ids: Optional[Union[List[int], List[List[int]]]], + prompts: Optional[Union[str, list[str]]], + prompt_token_ids: Optional[Union[list[int], list[list[int]]]], prompt_embeds: Optional[torch.Tensor] = None, ): # skip_tokenizer_init is now checked in engine @@ -1277,6 +1294,13 @@ def _convert_v1_inputs( parsed_prompts.append(item) + # Handle prompt_embeds if provided + if prompt_embeds is not None: + # Assuming prompt_embeds is a tensor that can be assigned to the first prompt + # This might need adjustment based on how prompt_embeds is actually used + if len(parsed_prompts) > 0 and hasattr(parsed_prompts[0], "prompt_embeds"): + parsed_prompts[0].prompt_embeds = prompt_embeds + return parsed_prompts def _validate_and_add_requests( @@ -1414,3 +1438,4 @@ def _run_engine( # This is necessary because some requests may be finished earlier than # its previous requests. return sorted(outputs, key=lambda x: int(x.request_id)) + diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 8671fdc28276..00c1fb8f28ab 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -461,9 +461,7 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds, self.lm_head.bias) + hidden_states = self.model(input_ids, positions,intermediate_tensors, inputs_embeds, self.lm_head.bias) return hidden_states def compute_logits( From c60298a48faa6dd26a88616f15244aff252d9d5b Mon Sep 17 00:00:00 2001 From: Bryce1010 Date: Thu, 6 Mar 2025 15:30:38 +0800 Subject: [PATCH 04/96] (vllm) fix pre commit error Signed-off-by: Andrew Sansom --- vllm/entrypoints/llm.py | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 315c2aaef4f8..f59119317264 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -282,8 +282,6 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - prompt_embeds: Optional[torch.Tensor] = None, - priority: Optional[list[int]] = None, ) -> list[RequestOutput]: ... @@ -300,8 +298,6 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - prompt_embeds: Optional[torch.Tensor] = None, - priority: Optional[list[int]] = None, ) -> list[RequestOutput]: ... 
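The qwen2.py hunk above forwards inputs_embeds straight into the model. A generic sketch of the pattern the model side relies on (hypothetical helper name, not the actual Qwen2Model code):

def _embed_inputs(embed_tokens, input_ids, inputs_embeds=None):
    # embed_tokens is the model's token-embedding module
    # (e.g. VocabParallelEmbedding). When the runner supplies precomputed
    # inputs_embeds, the vocabulary lookup is skipped entirely.
    if inputs_embeds is not None:
        return inputs_embeds
    return embed_tokens(input_ids)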
@@ -318,8 +314,6 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - prompt_embeds: Optional[torch.Tensor] = None, - priority: Optional[list[int]] = None, ) -> list[RequestOutput]: ... @@ -337,8 +331,6 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - prompt_embeds: Optional[torch.Tensor] = None, - priority: Optional[list[int]] = None, ) -> list[RequestOutput]: ... @@ -356,8 +348,6 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - prompt_embeds: Optional[torch.Tensor] = None, - priority: Optional[list[int]] = None, ) -> list[RequestOutput]: ... @@ -373,8 +363,6 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - prompt_embeds: Optional[torch.Tensor] = None, - priority: Optional[list[int]] = None, ) -> list[RequestOutput]: ... @@ -455,13 +443,15 @@ def generate( parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, list[str]]], prompts), prompt_token_ids=prompt_token_ids, - prompt_embeds=prompt_embeds, ) else: parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], prompts) - if prompt_embeds is not None and hasattr(parsed_prompts, "prompt_embeds"): - parsed_prompts.prompt_embeds = prompt_embeds + + # Handle prompt_embeds separately + # This is a simplified approach - you may need to adjust based on how prompt_embeds is used + if prompt_embeds is not None: + parsed_prompts.prompt_embeds = prompt_embeds if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: @@ -1294,15 +1284,10 @@ def _convert_v1_inputs( parsed_prompts.append(item) - # Handle prompt_embeds if provided - if prompt_embeds is not None: - # Assuming prompt_embeds is a tensor that can be assigned to the first prompt - # This might need adjustment based on how prompt_embeds is actually used - if len(parsed_prompts) > 0 and hasattr(parsed_prompts[0], "prompt_embeds"): - parsed_prompts[0].prompt_embeds = prompt_embeds - + # We don't need to handle prompt_embeds here since it's handled in the generate method return parsed_prompts +<<<<<<< HEAD def _validate_and_add_requests( self, prompts: Union[PromptType, Sequence[PromptType]], @@ -1439,3 +1424,5 @@ def _run_engine( # its previous requests. 
return sorted(outputs, key=lambda x: int(x.request_id)) +======= +>>>>>>> def7f49d ((vllm) fix pre commit error) From 0c24a823921e9e1636f506a22cff7fae439c8091 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 24 Mar 2025 20:39:50 -0500 Subject: [PATCH 05/96] apply ruff and isort fixes Signed-off-by: Andrew Sansom --- vllm/engine/llm_engine.py | 7 ++++--- vllm/entrypoints/llm.py | 24 ++++++++++-------------- vllm/model_executor/models/qwen2.py | 3 ++- vllm/sequence.py | 19 ++++++++++--------- vllm/worker/model_runner.py | 13 +++++++------ 5 files changed, 33 insertions(+), 33 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f476d2b08b0a..a02bfd28275b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -772,9 +772,10 @@ def add_request( if arrival_time is None: arrival_time = time.time() - if isinstance(prompt, dict) and prompt.get("prompt_embeds", None) is not None: - if not prompt.get("prompt_token_ids", None): - prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[0] + if (isinstance(prompt, dict) + and prompt.get("prompt_embeds", None) is not None + and not prompt.get("prompt_token_ids", None)): + prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[0] if self.tokenizer is not None: self._validate_token_prompt( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f59119317264..d8e65b248c9b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,14 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 import itertools -import warnings from collections.abc import Sequence from contextlib import contextmanager from typing import Any, Callable, ClassVar, Optional, Union, cast, overload import cloudpickle +import torch import torch.nn as nn -from tqdm import tqdm from typing_extensions import TypeVar, deprecated import torch from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, @@ -36,8 +35,7 @@ ScoringRequestOutput) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, - RequestOutputKind, SamplingParams) +from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -401,10 +399,10 @@ def generate( When it is a single value, it is applied to every prompt. When it is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. - prompt_token_ids: DEPRECATED. Token IDs for the prompts. If provided, - the `prompts` will be ignored. - prompt_embeds: Optional tensor of prompt embeddings to use instead of - text prompts. + prompt_token_ids: DEPRECATED. Token IDs for the prompts. If + provided, the `prompts` will be ignored. + prompt_embeds: Optional tensor of prompt embeddings to use instead + of text prompts. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. 
prompt_adapter_request: Prompt Adapter request to use for @@ -449,7 +447,8 @@ def generate( prompts) # Handle prompt_embeds separately - # This is a simplified approach - you may need to adjust based on how prompt_embeds is used + # This is a simplified approach - you may need to adjust based on how + # prompt_embeds is used if prompt_embeds is not None: parsed_prompts.prompt_embeds = prompt_embeds @@ -1284,10 +1283,10 @@ def _convert_v1_inputs( parsed_prompts.append(item) - # We don't need to handle prompt_embeds here since it's handled in the generate method + # We don't need to handle prompt_embeds here since it's handled in the + # generate method return parsed_prompts -<<<<<<< HEAD def _validate_and_add_requests( self, prompts: Union[PromptType, Sequence[PromptType]], @@ -1423,6 +1422,3 @@ def _run_engine( # This is necessary because some requests may be finished earlier than # its previous requests. return sorted(outputs, key=lambda x: int(x.request_id)) - -======= ->>>>>>> def7f49d ((vllm) fix pre commit error) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 00c1fb8f28ab..a6750f2025cf 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -461,7 +461,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions,intermediate_tensors, inputs_embeds, self.lm_head.bias) + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, self.lm_head.bias) return hidden_states def compute_logits( diff --git a/vllm/sequence.py b/vllm/sequence.py index ecbfb899075b..d5549a02a113 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -264,7 +264,7 @@ def prompt_token_ids_array(self) -> array: @property def output_token_ids(self) -> tuple[int, ...]: return tuple(self._output_token_ids) - + @output_token_ids.setter def output_token_ids(self, new_output_token_ids: GenericSequence[int]) -> None: @@ -281,11 +281,11 @@ def output_token_ids_array(self) -> array: """ assert isinstance(self._output_token_ids, array) return self._output_token_ids - + @property def prompt_embeds(self) -> Optional[torch.Tensor]: return self._prompt_embeds - + @prompt_embeds.setter def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None: self._prompt_embeds = prompt_embeds @@ -395,12 +395,13 @@ def stage(self) -> SequenceStage: return self._stage def __repr__(self) -> str: - return (f"SequenceData(" - f"prompt_token_ids={self._prompt_token_ids}, " - f"prompt_embeds={getattr(self._prompt_embeds, 'shape', None)}, " - f"output_token_ids={self.output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"get_num_computed_tokens={self.get_num_computed_tokens()})") + return ( + f"SequenceData(" + f"prompt_token_ids={self._prompt_token_ids}, " + f"prompt_embeds={getattr(self._prompt_embeds, 'shape', None)}, " + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"get_num_computed_tokens={self.get_num_computed_tokens()})") class Sequence: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0f44446ca888..5ecde479b7a2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -365,9 +365,8 @@ def __init__( else: self.input_tokens = input_tokens or [] - self.inputs_embeds = ( - inputs_embeds if inputs_embeds is not None else None - ) + self.inputs_embeds = (inputs_embeds + if inputs_embeds is not None else None) 
self.input_positions = input_positions or [] self.token_types = token_types or [] self.mrope_input_positions = mrope_input_positions or None @@ -545,8 +544,9 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute tokens. tokens = seq_data.get_token_ids()[context_len:seq_len] - if seq_data.prompt_embeds is not None and seq_data.get_output_len() == 0: - prompt_embeds = seq_data.prompt_embeds[context_len:seq_len] + if seq_data.prompt_embeds is not None and seq_data.get_output_len( + ) == 0: + prompt_embeds = seq_data.prompt_embeds[context_len:seq_len] else: seq_data.prompt_embeds = None # release memory prompt_embeds = None @@ -871,7 +871,8 @@ def build(self) -> ModelInputForGPU: for cur_token_types in inter_data.token_types: token_types.extend(cur_token_types) if inter_data.inputs_embeds is not None: - inputs_embeds.append(inter_data.inputs_embeds.to(self.runner.device)) + inputs_embeds.append( + inter_data.inputs_embeds.to(self.runner.device)) if len(inputs_embeds) == 0: inputs_embeds = None elif len(inputs_embeds) == 1: From 403a16550a0951ac0f236a6b77bcfdd29ca0db00 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 24 Mar 2025 20:55:40 -0500 Subject: [PATCH 06/96] apply ruff and isort fixes Signed-off-by: Andrew Sansom --- vllm/engine/async_llm_engine.py | 12 ++++++------ vllm/entrypoints/llm.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 554590bb3a15..447dee4c4373 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -489,12 +489,12 @@ async def add_request_async( if arrival_time is None: arrival_time = time.time() - if isinstance(prompt, dict) and prompt.get("prompt_embeds", - None) is not None: - if not prompt.get("prompt_token_ids", None): - prompt["prompt_token_ids"] = [ - 0 - ] * prompt["prompt_embeds"].shape[0] + if ( + isinstance(prompt, dict) + and prompt.get("prompt_embeds", None) is not None + and not prompt.get("prompt_token_ids", None) + ): + prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[0] if self.tokenizer is not None: tokenizer = await self.get_tokenizer_async(lora_request) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d8e65b248c9b..d5a071158866 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn from typing_extensions import TypeVar, deprecated -import torch + from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) from vllm.config import CompilationConfig From b1ac0721836e800932275a08d4afee091b7b7c56 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 24 Mar 2025 21:04:47 -0500 Subject: [PATCH 07/96] styling Signed-off-by: Andrew Sansom --- vllm/engine/async_llm_engine.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 447dee4c4373..cc85310c3a1f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -489,11 +489,9 @@ async def add_request_async( if arrival_time is None: arrival_time = time.time() - if ( - isinstance(prompt, dict) - and prompt.get("prompt_embeds", None) is not None - and not prompt.get("prompt_token_ids", None) - ): + if (isinstance(prompt, dict) + and prompt.get("prompt_embeds", None) is not None + and not prompt.get("prompt_token_ids", None)): prompt["prompt_token_ids"] = [0] * 
prompt["prompt_embeds"].shape[0] if self.tokenizer is not None: From 0390c3342a6ebb59683ff53ae5032669dfac4fe1 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 24 Mar 2025 21:14:43 -0500 Subject: [PATCH 08/96] fix missing imports from rebase Signed-off-by: Andrew Sansom --- vllm/entrypoints/llm.py | 5 ++++- vllm/inputs/data.py | 3 --- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d5a071158866..6b14522c42e5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import itertools +import warnings from collections.abc import Sequence from contextlib import contextmanager from typing import Any, Callable, ClassVar, Optional, Union, cast, overload @@ -8,6 +9,7 @@ import cloudpickle import torch import torch.nn as nn +from tqdm import tqdm from typing_extensions import TypeVar, deprecated from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, @@ -35,7 +37,8 @@ ScoringRequestOutput) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, + RequestOutputKind, SamplingParams) from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index aea9b77b44b7..8be04b766a15 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -156,9 +156,6 @@ class TokenInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ - prompt_embeds: NotRequired[torch.Tensor] - """The embeddings of the prompt, if available.""" - multi_modal_data: NotRequired["MultiModalDataDict"] """ Optional multi-modal data to pass to the model, From 0ca4daec8eaebfd4a61720b22364fd0abed7b177 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 24 Mar 2025 21:33:43 -0500 Subject: [PATCH 09/96] typing fixes Signed-off-by: Andrew Sansom --- vllm/entrypoints/llm.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 6b14522c42e5..272942ecd817 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -278,6 +278,7 @@ def generate( sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, *, + prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -294,6 +295,7 @@ def generate( sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None, prompt_token_ids: Optional[list[int]] = None, + prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -310,6 +312,7 @@ def generate( sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None, prompt_token_ids: Optional[list[list[int]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -327,6 +330,7 @@ def generate( list[SamplingParams]]] = None, *, 
prompt_token_ids: list[int], + prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -344,6 +348,7 @@ def generate( list[SamplingParams]]] = None, *, prompt_token_ids: list[list[int]], + prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -359,6 +364,7 @@ def generate( prompts: None, sampling_params: None, prompt_token_ids: Union[list[int], list[list[int]]], + prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -1368,8 +1374,12 @@ def _add_guided_params( raise ValueError("Cannot set both guided_options_request and " "params.guided_decoding.") + json = (guided_options.guided_json.model_dump_json() if + (guided_options.guided_json is not None + and not isinstance(guided_options.guided_json, (dict, str))) + else guided_options.guided_json) params.guided_decoding = GuidedDecodingParams( - json=guided_options.guided_json, + json=json, regex=guided_options.guided_regex, choice=guided_options.guided_choice, grammar=guided_options.guided_grammar, From 35320fe4a3a10f7eecebe2e4222a6bdba69f2b84 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 24 Mar 2025 21:47:40 -0500 Subject: [PATCH 10/96] type fix Signed-off-by: Andrew Sansom --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5ecde479b7a2..9133f0ee4ef8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -863,7 +863,7 @@ def build(self) -> ModelInputForGPU: """ # Combine and flatten intermediate data. input_tokens = [] - inputs_embeds = [] + inputs_embeds: Union[list[torch.Tensor], None] = [] token_types = [] for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: From 0a77630caa569e919e3e1909a2cd0ebb51801696 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 24 Mar 2025 22:01:19 -0500 Subject: [PATCH 11/96] type fix Signed-off-by: Andrew Sansom --- vllm/worker/model_runner.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9133f0ee4ef8..9d9781459b1b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -863,7 +863,7 @@ def build(self) -> ModelInputForGPU: """ # Combine and flatten intermediate data. 
input_tokens = [] - inputs_embeds: Union[list[torch.Tensor], None] = [] + inputs_embeds_: list[torch.Tensor] = [] token_types = [] for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: @@ -871,14 +871,15 @@ def build(self) -> ModelInputForGPU: for cur_token_types in inter_data.token_types: token_types.extend(cur_token_types) if inter_data.inputs_embeds is not None: - inputs_embeds.append( + inputs_embeds_.append( inter_data.inputs_embeds.to(self.runner.device)) - if len(inputs_embeds) == 0: + inputs_embeds: Optional[torch.Tensor] + if len(inputs_embeds_) == 0: inputs_embeds = None - elif len(inputs_embeds) == 1: - inputs_embeds = inputs_embeds[0] + elif len(inputs_embeds_) == 1: + inputs_embeds = inputs_embeds_[0] else: - inputs_embeds = torch.cat(inputs_embeds, dim=0) + inputs_embeds = torch.cat(inputs_embeds_, dim=0) if not input_tokens: # This may happen when all prefill requests hit From 11b6c02a866a97fd59d18417263df16159f9e3ce Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 25 Mar 2025 09:35:40 -0500 Subject: [PATCH 12/96] remove unnecessary changes Signed-off-by: Andrew Sansom --- vllm/worker/model_runner.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9d9781459b1b..2f4fd74223fb 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1928,9 +1928,6 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: kv_caches: vLLM's paged memory """ - if self.vllm_config.kv_transfer_config is None: - return False - prefill_meta = model_input.attn_metadata.prefill_metadata # check if the current run is profiling @@ -1938,9 +1935,6 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: # check if the current run is prefill is_prefill_run = prefill_meta is not None - if self.vllm_config.kv_transfer_config is None: - return False - return self.vllm_config.kv_transfer_config.is_kv_consumer and ( not is_profile_run) and is_prefill_run From cb92a3ddceef771a2062da8560ae33e7b5b34d30 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 25 Mar 2025 09:37:04 -0500 Subject: [PATCH 13/96] remove unnecessary changes Signed-off-by: Andrew Sansom --- vllm/worker/model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2f4fd74223fb..b26b6c3cdac7 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1928,6 +1928,9 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: kv_caches: vLLM's paged memory """ + if self.vllm_config.kv_transfer_config is None: + return False + prefill_meta = model_input.attn_metadata.prefill_metadata # check if the current run is profiling @@ -1960,9 +1963,6 @@ def need_send_kv(self, model_input, kv_caches) -> bool: # check if the current run is prefill is_prefill_run = prefill_meta is not None - if self.vllm_config.kv_transfer_config is None: - return False - return self.vllm_config.kv_transfer_config.is_kv_producer and ( not is_profile_run) and is_prefill_run From 375bd5b4e554c6e84ad88da87d951900d41a6a21 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 25 Mar 2025 13:33:37 -0500 Subject: [PATCH 14/96] re-add deleted whitespace Signed-off-by: Andrew Sansom --- vllm/attention/backends/flash_attn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 3b5e54980195..27bd292b51f2 100755 --- a/vllm/attention/backends/flash_attn.py +++ 
b/vllm/attention/backends/flash_attn.py @@ -769,6 +769,7 @@ def forward( prefill_output = output[:num_prefill_query_tokens] assert query.shape[0] == num_prefill_query_tokens assert decode_query.shape[0] == num_decode_query_tokens + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if (kv_cache.numel() == 0 or prefill_meta.block_tables is None From c9d8024831ac7d20df42ea078e30ce7c0e2559b4 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 25 Mar 2025 17:04:56 -0500 Subject: [PATCH 15/96] Include unit tests from #6869. Co-authored-by: Nan2018 Signed-off-by: Andrew Sansom --- tests/conftest.py | 17 ++-- .../decoder_only/language/test_models.py | 19 ++++ tests/worker/test_model_runner.py | 92 +++++++++++++++---- vllm/sequence.py | 11 ++- vllm/worker/model_runner.py | 16 ++-- 5 files changed, 116 insertions(+), 39 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b833cff4db7c..83d5bfecb42e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -742,7 +742,7 @@ def __init__( def get_inputs( self, - prompts: list[str], + prompts: Union[list[str], list[torch.Tensor]], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, @@ -764,16 +764,19 @@ def get_inputs( if audios is not None and (audio := audios[i]) is not None: multi_modal_data["audio"] = audio + text_prompt_kwargs = { + ("prompt" if isinstance(prompt, str) else "prompt_embeds"): prompt, + "multi_modal_data": multi_modal_data or None + } inputs.append( - TextPrompt(prompt=prompt, - multi_modal_data=multi_modal_data - if multi_modal_data else None)) + TextPrompt(**text_prompt_kwargs) + ) return inputs def generate( self, - prompts: list[str], + prompts: Union[list[str], list[torch.Tensor]], sampling_params: SamplingParams, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, @@ -799,7 +802,7 @@ def generate( output_str = sample.text output_ids = list(sample.token_ids) req_sample_output_ids.append(prompt_ids + output_ids) - req_sample_output_strs.append(prompt_str + output_str) + req_sample_output_strs.append(prompt_str or "" + output_str) outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs @@ -866,7 +869,7 @@ def generate_encoder_decoder_w_logprobs( def generate_greedy( self, - prompts: list[str], + prompts: Union[list[str], list[torch.Tensor]], max_tokens: int, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 79fa3fa99773..e7fdbb26b546 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -113,9 +113,21 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) + prompt_embeds = [] + prompt_token_ids = [] + for prompt in example_prompts: + token_ids = hf_model.tokenizer(prompt, + return_tensors="pt").input_ids.to( + hf_model.model.device) + prompt_token_ids.append(token_ids) + prompt_embeds.append( + hf_model.model.get_input_embeddings()(token_ids).squeeze(0)) + with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) + vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs( + prompt_embeds, max_tokens, num_logprobs) check_logprobs_close( 
outputs_0_lst=hf_outputs, @@ -123,6 +135,13 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, name_0="hf", name_1="vllm", ) + check_logprobs_close( + outputs_0_lst=vllm_outputs, + outputs_1_lst=vllm_outputs_from_embeds, + name_0="vllm", + name_1="vllm_from_embeds", + ) + if use_rocm_aiter: # this is to ensure that vllm engine # has deallocated the memory before running the next diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index b8ba69b0dd8f..c8bd2a3d287c 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import random + import pytest import torch @@ -31,8 +33,9 @@ def test_deepseek_mla_attn_backend_module(): assert model_runner.attn_backend.__name__ == "TritonMLABackend" -@pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_prompt(batch_size): +@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) +@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0)) +def test_prepare_prompt(batch_size, prompt_embeds_ratio): model_runner = _create_model_runner( "facebook/opt-125m", max_num_batched_tokens=100000, @@ -43,11 +46,20 @@ def test_prepare_prompt(batch_size): seq_lens: list[int] = [] seq_group_metadata_list: list[SequenceGroupMetadata] = [] block_tables = {0: [1]} + input_embeds_len = 0 for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData.from_seqs(range(seq_len)) + if random.random() < prompt_embeds_ratio: + seq_data = SequenceData.from_seqs( + range(seq_len), + prompt_embeds=torch.rand(seq_len, 10), + ) + input_embeds_len += seq_len + else: + seq_data = SequenceData.from_seqs(range(seq_len)) + seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -68,6 +80,7 @@ def test_prepare_prompt(batch_size): seq_group_metadata_list) input_tokens = model_input.input_tokens input_positions = model_input.input_positions + input_embeds = model_input.inputs_embeds attn_metadata = model_input.attn_metadata return_seq_lens = model_input.seq_lens slot_mapping = attn_metadata.slot_mapping @@ -121,7 +134,11 @@ def test_prepare_prompt(batch_size): assert len(input_tokens) == sum(seq_lens) assert len(input_positions) == sum(seq_lens) - torch.testing.assert_close(input_tokens, input_positions) + if input_embeds_len == 0: + torch.testing.assert_close(input_tokens, input_positions) + assert input_embeds is None + else: + assert len(input_embeds) == input_embeds_len sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, @@ -145,8 +162,9 @@ def test_prepare_prompt(batch_size): torch.testing.assert_close(actual, expected) -@pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_decode_cuda_graph(batch_size): +@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) +@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0)) +def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio): model_runner = _create_model_runner( "facebook/opt-125m", seed=0, @@ -160,11 +178,19 @@ def test_prepare_decode_cuda_graph(batch_size): context_lens: list[int] = [] seq_group_metadata_list: list[SequenceGroupMetadata] = [] # Assume each seq group finishes prefill. 
+ input_embeds_len = 0 for i in range(batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) - seq_data = SequenceData.from_seqs(range(context_len)) + if random.random() < prompt_embeds_ratio: + seq_data = SequenceData.from_seqs( + [], + prompt_embeds=torch.rand(context_len, 10), + ) + input_embeds_len += context_len + else: + seq_data = SequenceData.from_seqs(range(context_len)) seq_data.update_num_computed_tokens(context_len) # Append one token ID since prefill is finished. seq_data.append_token_id(1, 0) @@ -180,9 +206,10 @@ def test_prepare_decode_cuda_graph(batch_size): model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, attn_metadata, slot_mapping = ( + input_tokens, input_positions, input_embeds, attn_metadata, slot_mapping = ( model_input.input_tokens, model_input.input_positions, - model_input.attn_metadata, model_input.attn_metadata.slot_mapping) + model_input.inputs_embeds, model_input.attn_metadata, + model_input.attn_metadata.slot_mapping) assert len(slot_mapping) == len(input_tokens) expected_bs = model_runner.vllm_config.pad_for_cudagraph( @@ -236,6 +263,7 @@ def test_prepare_decode_cuda_graph(batch_size): assert len(input_tokens) == expected_bs assert len(input_positions) == expected_bs torch.allclose(input_tokens, input_positions) + assert input_embeds is None # Verify Sampling expected_selected_token_indices = [] @@ -277,14 +305,17 @@ def test_empty_seq_group(): model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - (input_tokens, input_positions, attn_metadata, return_seq_lens) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - model_input.seq_lens, - ) + (input_tokens, input_positions, input_embeds, attn_metadata, + return_seq_lens) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.inputs_embeds, + model_input.attn_metadata, + model_input.seq_lens, + ) assert input_tokens is None assert input_positions is None + assert input_embeds is None assert attn_metadata is None assert return_seq_lens is None @@ -299,9 +330,11 @@ def distributed_init(): ensure_model_parallel_initialized(1, 1) -@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("batch_size", list(range(2, 128, 3))) @pytest.mark.parametrize("enforce_eager", [True, False]) -def test_hybrid_batches(batch_size, enforce_eager, distributed_init): +@pytest.mark.parametrize('prompt_embeds_ratio', [0.0, 0.5, 1.0]) +def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, + distributed_init): model_runner = _create_model_runner( "facebook/opt-125m", seed=0, @@ -320,11 +353,19 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): block_tables = {0: [1]} prefill_batch_size = batch_size // 2 decode_batch_size = batch_size - prefill_batch_size + input_embeds_len = 0 for i in range(prefill_batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData.from_seqs(range(seq_len)) + if random.random() < prompt_embeds_ratio: + seq_data = SequenceData.from_seqs( + [], + prompt_embeds=torch.rand(seq_len, 10), + ) + input_embeds_len += seq_len + else: + seq_data = SequenceData.from_seqs(range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -340,7 +381,13 @@ def 
test_hybrid_batches(batch_size, enforce_eager, distributed_init): for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData.from_seqs(range(context_len)) + if random.random() < prompt_embeds_ratio: + seq_data = SequenceData.from_seqs( + [], + prompt_embeds=torch.rand(context_len, 10), + ) + else: + seq_data = SequenceData.from_seqs(range(context_len)) seq_data.append_token_id(1, 0) seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( @@ -355,9 +402,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): decode_metadata_list.append(seq_group_metadata) model_input = model_runner.prepare_model_input(seq_group_metadata_list) - (input_tokens, input_positions, attn_metadata) = ( + (input_tokens, input_positions, input_embeds, attn_metadata) = ( model_input.input_tokens, model_input.input_positions, + model_input.inputs_embeds, model_input.attn_metadata, ) @@ -369,6 +417,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert attn_metadata.num_prefills == prefill_batch_size assert attn_metadata.num_decode_tokens == decode_batch_size assert attn_metadata.num_prefill_tokens == sum(seq_lens) + if input_embeds_len == 0: + assert input_embeds is None + else: + assert len(input_embeds) == input_embeds_len # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. diff --git a/vllm/sequence.py b/vllm/sequence.py index d5549a02a113..666a1806720d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -210,6 +210,8 @@ def from_prompt_token_counts( def from_seqs( prompt_token_ids: GenericSequence[int], output_token_ids: Optional[GenericSequence[int]] = None, + *, + prompt_embeds: Optional[torch.Tensor] = None, ) -> "SequenceData": """ Construct a :class:`SequenceData` instance from prompt and output @@ -219,13 +221,15 @@ def from_seqs( prompt_token_ids) if output_token_ids is None: - return SequenceData(prompt_token_ids_arr) + return SequenceData(prompt_token_ids_arr, + _prompt_embeds=prompt_embeds) output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, output_token_ids) return SequenceData(prompt_token_ids_arr, - _output_token_ids=output_token_ids_arr) + _output_token_ids=output_token_ids_arr, + _prompt_embeds=prompt_embeds) def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" @@ -305,7 +309,8 @@ def append_token_id(self, token_id: int, logprob: float) -> None: self._cumulative_logprob += logprob def get_len(self) -> int: - return len(self._output_token_ids) + len(self._prompt_token_ids) + return len(self._output_token_ids) + len(self._prompt_token_ids) + ( + len(self._prompt_embeds) if self._prompt_embeds is not None else 0) def get_prompt_len(self) -> int: return len(self._prompt_token_ids) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b26b6c3cdac7..311147351b94 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -365,8 +365,7 @@ def __init__( else: self.input_tokens = input_tokens or [] - self.inputs_embeds = (inputs_embeds - if inputs_embeds is not None else None) + self.inputs_embeds = inputs_embeds self.input_positions = input_positions or [] self.token_types = token_types or [] self.mrope_input_positions = mrope_input_positions or None @@ -543,13 +542,12 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, context_len = 
seq_data.get_num_computed_tokens() # Compute tokens. - tokens = seq_data.get_token_ids()[context_len:seq_len] - if seq_data.prompt_embeds is not None and seq_data.get_output_len( - ) == 0: - prompt_embeds = seq_data.prompt_embeds[context_len:seq_len] - else: - seq_data.prompt_embeds = None # release memory + if seq_data.prompt_embeds is None: + tokens = seq_data.get_token_ids()[context_len:seq_len] prompt_embeds = None + else: + tokens = [0] * (seq_len - context_len) + prompt_embeds = seq_data.prompt_embeds[context_len:seq_len] token_types = seq_group_metadata.token_type_ids @@ -881,7 +879,7 @@ def build(self) -> ModelInputForGPU: else: inputs_embeds = torch.cat(inputs_embeds_, dim=0) - if not input_tokens: + if not input_tokens and inputs_embeds is None: # This may happen when all prefill requests hit # prefix caching and there is no decode request. return self.model_input_cls() From a64e62745c8833527368d03bb3cf62fde34bc5fb Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 25 Mar 2025 22:48:04 -0500 Subject: [PATCH 16/96] remove unrelated qwen2 changes Signed-off-by: Andrew Sansom --- vllm/model_executor/models/qwen2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index a6750f2025cf..c4d02e5ddeb1 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -437,7 +437,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - True, quant_config=quant_config, prefix=maybe_prefix( prefix, "lm_head")) @@ -460,9 +459,8 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds, self.lm_head.bias) + inputs_embeds) return hidden_states def compute_logits( From 6ab349eba0727f6d3bda4a2ff01f83137c763489 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Wed, 26 Mar 2025 21:30:13 -0500 Subject: [PATCH 17/96] guard clause around fully consumed prompt embeds to avoid returning empty tensors instead of none Signed-off-by: Andrew Sansom --- tests/worker/test_model_runner.py | 2 +- vllm/worker/model_runner.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index c8bd2a3d287c..0c1193853aa0 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -53,7 +53,7 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio): seq_lens.append(seq_len) if random.random() < prompt_embeds_ratio: seq_data = SequenceData.from_seqs( - range(seq_len), + [], prompt_embeds=torch.rand(seq_len, 10), ) input_embeds_len += seq_len diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 311147351b94..fcb65dd726b0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -548,6 +548,10 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, else: tokens = [0] * (seq_len - context_len) prompt_embeds = seq_data.prompt_embeds[context_len:seq_len] + if len(prompt_embeds) == 0: + # Sometimes the prompt_embeds can be fully processed, so the + # seq_data.prompt_embeds[context_len:seq_len] can be empty. 
+ prompt_embeds = None token_types = seq_group_metadata.token_type_ids From 26c87840f4d19cd4ed4f9d3b6a0730bcb3357648 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Wed, 26 Mar 2025 22:38:57 -0500 Subject: [PATCH 18/96] use v0 for prompt embeds model runner tests Signed-off-by: Andrew Sansom --- tests/worker/test_model_runner.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 0c1193853aa0..c065d552bad2 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -35,7 +35,11 @@ def test_deepseek_mla_attn_backend_module(): @pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) @pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0)) -def test_prepare_prompt(batch_size, prompt_embeds_ratio): +def test_prepare_prompt(batch_size, prompt_embeds_ratio, monkeypatch): + if prompt_embeds_ratio > 0.0: + # Prompt Embeddings is only currently supported on V0 + monkeypatch.setenv("VLLM_USE_V1", "0") + model_runner = _create_model_runner( "facebook/opt-125m", max_num_batched_tokens=100000, @@ -164,7 +168,12 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio): @pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) @pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0)) -def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio): +def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio, + monkeypatch): + if prompt_embeds_ratio > 0.0: + # Prompt Embeddings is only currently supported on V0 + monkeypatch.setenv("VLLM_USE_V1", "0") + model_runner = _create_model_runner( "facebook/opt-125m", seed=0, @@ -334,7 +343,11 @@ def distributed_init(): @pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize('prompt_embeds_ratio', [0.0, 0.5, 1.0]) def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, - distributed_init): + distributed_init, monkeypatch): + if prompt_embeds_ratio > 0.0: + # Prompt Embeddings is only currently supported on V0 + monkeypatch.setenv("VLLM_USE_V1", "0") + model_runner = _create_model_runner( "facebook/opt-125m", seed=0, From b71a13c2bef492f48f36c9d6ad8d4e6b2841b7df Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Wed, 2 Apr 2025 16:47:32 -0500 Subject: [PATCH 19/96] fix batching of input embeddings Signed-off-by: Andrew Sansom --- vllm/engine/async_llm_engine.py | 5 ++++- vllm/engine/llm_engine.py | 5 ++++- vllm/entrypoints/llm.py | 14 -------------- vllm/sequence.py | 3 +-- vllm/worker/model_runner.py | 13 +++++++++---- 5 files changed, 18 insertions(+), 22 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index cc85310c3a1f..ec1cffeb83d5 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -492,7 +492,10 @@ async def add_request_async( if (isinstance(prompt, dict) and prompt.get("prompt_embeds", None) is not None and not prompt.get("prompt_token_ids", None)): - prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[0] + # We use the -2 dimension (instead of 0) in case a batched input + # of batch size 1 is passed in. 
+ prompt["prompt_token_ids"] = [0 + ] * prompt["prompt_embeds"].shape[-2] if self.tokenizer is not None: tokenizer = await self.get_tokenizer_async(lora_request) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a02bfd28275b..038dcf9b9af6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -775,7 +775,10 @@ def add_request( if (isinstance(prompt, dict) and prompt.get("prompt_embeds", None) is not None and not prompt.get("prompt_token_ids", None)): - prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[0] + # We use the -2 dimension (instead of 0) in case a batched input + # of batch size 1 is passed in. + prompt["prompt_token_ids"] = [0 + ] * prompt["prompt_embeds"].shape[-2] if self.tokenizer is not None: self._validate_token_prompt( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 272942ecd817..1a037fe1f9ff 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -278,7 +278,6 @@ def generate( sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, *, - prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -295,7 +294,6 @@ def generate( sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None, prompt_token_ids: Optional[list[int]] = None, - prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -330,7 +328,6 @@ def generate( list[SamplingParams]]] = None, *, prompt_token_ids: list[int], - prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -348,7 +345,6 @@ def generate( list[SamplingParams]]] = None, *, prompt_token_ids: list[list[int]], - prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -364,7 +360,6 @@ def generate( prompts: None, sampling_params: None, prompt_token_ids: Union[list[int], list[list[int]]], - prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -385,7 +380,6 @@ def generate( sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, - prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -410,8 +404,6 @@ def generate( prompts and it is paired one by one with the prompt. prompt_token_ids: DEPRECATED. Token IDs for the prompts. If provided, the `prompts` will be ignored. - prompt_embeds: Optional tensor of prompt embeddings to use instead - of text prompts. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. 
prompt_adapter_request: Prompt Adapter request to use for @@ -455,12 +447,6 @@ def generate( parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], prompts) - # Handle prompt_embeds separately - # This is a simplified approach - you may need to adjust based on how - # prompt_embeds is used - if prompt_embeds is not None: - parsed_prompts.prompt_embeds = prompt_embeds - if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: raise ValueError( diff --git a/vllm/sequence.py b/vllm/sequence.py index 666a1806720d..4b71cf3f74a3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -309,8 +309,7 @@ def append_token_id(self, token_id: int, logprob: float) -> None: self._cumulative_logprob += logprob def get_len(self) -> int: - return len(self._output_token_ids) + len(self._prompt_token_ids) + ( - len(self._prompt_embeds) if self._prompt_embeds is not None else 0) + return len(self._output_token_ids) + len(self._prompt_token_ids) def get_prompt_len(self) -> int: return len(self._prompt_token_ids) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fcb65dd726b0..c4734da5ba83 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -874,14 +874,19 @@ def build(self) -> ModelInputForGPU: token_types.extend(cur_token_types) if inter_data.inputs_embeds is not None: inputs_embeds_.append( - inter_data.inputs_embeds.to(self.runner.device)) + inter_data.inputs_embeds.to( + dtype=self.runner.model_config.dtype, + device=self.runner.device)) inputs_embeds: Optional[torch.Tensor] if len(inputs_embeds_) == 0: inputs_embeds = None - elif len(inputs_embeds_) == 1: - inputs_embeds = inputs_embeds_[0] else: - inputs_embeds = torch.cat(inputs_embeds_, dim=0) + inputs_embeds = torch.cat([ + x.squeeze(dim=0) if x.dim() == 3 else x for x in inputs_embeds_ + ], + dim=0).to( + dtype=self.runner.model_config.dtype, + device=self.runner.device) if not input_tokens and inputs_embeds is None: # This may happen when all prefill requests hit From 4aa9ade4f6c9c8feebcc06aca9ac2180c4c12373 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Wed, 2 Apr 2025 16:52:09 -0500 Subject: [PATCH 20/96] style formatting Signed-off-by: Andrew Sansom --- tests/conftest.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 83d5bfecb42e..b3ee3f0462bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -765,12 +765,11 @@ def get_inputs( multi_modal_data["audio"] = audio text_prompt_kwargs = { - ("prompt" if isinstance(prompt, str) else "prompt_embeds"): prompt, + ("prompt" if isinstance(prompt, str) else "prompt_embeds"): + prompt, "multi_modal_data": multi_modal_data or None } - inputs.append( - TextPrompt(**text_prompt_kwargs) - ) + inputs.append(TextPrompt(**text_prompt_kwargs)) return inputs From e2c4c26d4a0b36b24bd4141b68c45746b7b6746d Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Wed, 2 Apr 2025 20:46:24 -0500 Subject: [PATCH 21/96] remove incorrect overload Signed-off-by: Andrew Sansom --- vllm/entrypoints/llm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1a037fe1f9ff..fcb7947cb1bd 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -310,7 +310,6 @@ def generate( sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None, prompt_token_ids: Optional[list[list[int]]] = None, - prompt_embeds: Optional[torch.Tensor] = None, use_tqdm: bool = True, lora_request: 
Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, From 26d108ac9c9aa88328b88c311857ba092ad0ff53 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Wed, 2 Apr 2025 20:48:05 -0500 Subject: [PATCH 22/96] remove incorrect overload Signed-off-by: Andrew Sansom --- vllm/entrypoints/llm.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fcb7947cb1bd..ddd4ee4c97e8 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -7,7 +7,6 @@ from typing import Any, Callable, ClassVar, Optional, Union, cast, overload import cloudpickle -import torch import torch.nn as nn from tqdm import tqdm from typing_extensions import TypeVar, deprecated @@ -1239,7 +1238,6 @@ def _convert_v1_inputs( self, prompts: Optional[Union[str, list[str]]], prompt_token_ids: Optional[Union[list[int], list[list[int]]]], - prompt_embeds: Optional[torch.Tensor] = None, ): # skip_tokenizer_init is now checked in engine @@ -1277,8 +1275,6 @@ def _convert_v1_inputs( parsed_prompts.append(item) - # We don't need to handle prompt_embeds here since it's handled in the - # generate method return parsed_prompts def _validate_and_add_requests( From af204355160c8f6b6b7f6072934b21d0b7496598 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Thu, 3 Apr 2025 21:31:15 -0500 Subject: [PATCH 23/96] Update representations Signed-off-by: Andrew Sansom Co-authored-by: Cyrus Leung --- vllm/sequence.py | 5 ++--- vllm/worker/model_runner.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 4b71cf3f74a3..4bad3a3b94a4 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -402,7 +402,7 @@ def __repr__(self) -> str: return ( f"SequenceData(" f"prompt_token_ids={self._prompt_token_ids}, " - f"prompt_embeds={getattr(self._prompt_embeds, 'shape', None)}, " + f"prompt_embeds.shape={getattr(self._prompt_embeds, 'shape', None)}, " f"output_token_ids={self.output_token_ids}, " f"cumulative_logprob={self.cumulative_logprob}, " f"get_num_computed_tokens={self.get_num_computed_tokens()})") @@ -441,8 +441,7 @@ def __init__( self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - self.data = SequenceData.from_seqs(self.prompt_token_ids) - self.data.prompt_embeds = self.inputs.prompt_embeds + self.data = SequenceData.from_seqs(self.prompt_token_ids, prompt_embeds=self.inputs.prompt_embeds self.output_logprobs: SampleLogprobs = [] self.output_text = "" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c4734da5ba83..b364c17d9b9b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -421,7 +421,7 @@ def __repr__(self) -> str: f"computed_block_nums={self.computed_block_nums}, " f"n_seqs={self.n_seqs}, " f"input_tokens={self.input_tokens}, " - f"inputs_embeds={getattr(self.inputs_embeds, 'shape', None)}, " + f"inputs_embeds.shape={getattr(self.inputs_embeds, 'shape', None)}, " f"input_positions={self.input_positions}, " f"token_types={self.token_types}, " f"mrope_input_positions={self.mrope_input_positions}, " From 25aaf3fe63a97637ca43c67c338cff60e4b88df1 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Thu, 3 Apr 2025 21:36:58 -0500 Subject: [PATCH 24/96] remove unrelated changes to docs Signed-off-by: Andrew Sansom --- vllm/entrypoints/llm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ddd4ee4c97e8..10d96a1b3766 100644 --- 
a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -400,13 +400,10 @@ def generate( When it is a single value, it is applied to every prompt. When it is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. - prompt_token_ids: DEPRECATED. Token IDs for the prompts. If - provided, the `prompts` will be ignored. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. prompt_adapter_request: Prompt Adapter request to use for generation, if any. - guided_options_request: Options for guided decoding, if any. priority: The priority of the requests, if any. Only applicable when priority scheduling policy is enabled. From bc0586016b98ae19baf31abd0ccce7e3e6d66003 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Thu, 3 Apr 2025 21:41:07 -0500 Subject: [PATCH 25/96] remove unrelated typing change Signed-off-by: Andrew Sansom --- vllm/entrypoints/llm.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 10d96a1b3766..f39b011c9301 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1352,12 +1352,8 @@ def _add_guided_params( raise ValueError("Cannot set both guided_options_request and " "params.guided_decoding.") - json = (guided_options.guided_json.model_dump_json() if - (guided_options.guided_json is not None - and not isinstance(guided_options.guided_json, (dict, str))) - else guided_options.guided_json) params.guided_decoding = GuidedDecodingParams( - json=json, + json=guided_options.guided_json, regex=guided_options.guided_regex, choice=guided_options.guided_choice, grammar=guided_options.guided_grammar, From b55800d477babafb51d87a0a6392a44d65a71a71 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Thu, 3 Apr 2025 22:13:57 -0500 Subject: [PATCH 26/96] fix missing syntax Signed-off-by: Andrew Sansom --- vllm/sequence.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 4bad3a3b94a4..e3deb7b4a748 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -399,13 +399,13 @@ def stage(self) -> SequenceStage: return self._stage def __repr__(self) -> str: - return ( - f"SequenceData(" - f"prompt_token_ids={self._prompt_token_ids}, " - f"prompt_embeds.shape={getattr(self._prompt_embeds, 'shape', None)}, " - f"output_token_ids={self.output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"get_num_computed_tokens={self.get_num_computed_tokens()})") + return (f"SequenceData(" + f"prompt_token_ids={self._prompt_token_ids}, " + f"prompt_embeds.shape=" + f"{getattr(self._prompt_embeds, 'shape', None)}, " + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"get_num_computed_tokens={self.get_num_computed_tokens()})") class Sequence: @@ -441,7 +441,8 @@ def __init__( self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - self.data = SequenceData.from_seqs(self.prompt_token_ids, prompt_embeds=self.inputs.prompt_embeds + self.data = SequenceData.from_seqs( + self.prompt_token_ids, prompt_embeds=self.inputs.prompt_embeds) self.output_logprobs: SampleLogprobs = [] self.output_text = "" From be42a17ef79498421e9eb14339c2d2859c285d97 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Thu, 3 Apr 2025 22:56:04 -0500 Subject: [PATCH 27/96] do not schedule prompt embeds and non-prompt embeds in the same batch Signed-off-by: Andrew Sansom 
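
The rule enforced here is that a scheduled batch must be homogeneous with
respect to prompt embeddings, since the concatenated inputs_embeds tensor
built by the model runner has to line up with the whole batch. A rough sketch
of the check, assuming only the uses_prompt_embeds() helper added to
SequenceGroup below; the split_by_prompt_embeds helper is illustrative and is
not the scheduler's actual control flow:

    def split_by_prompt_embeds(seq_groups):
        # The first group decides whether this batch carries prompt embeds;
        # groups that disagree are deferred to a later scheduling pass.
        if not seq_groups:
            return [], []
        using_prompt_embeds = seq_groups[0].uses_prompt_embeds()
        scheduled = [g for g in seq_groups
                     if g.uses_prompt_embeds() == using_prompt_embeds]
        deferred = [g for g in seq_groups
                    if g.uses_prompt_embeds() != using_prompt_embeds]
        return scheduled, deferred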
--- vllm/core/scheduler.py | 10 ++++++++++ vllm/sequence.py | 4 ++++ 2 files changed, 14 insertions(+) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index cf85a2135c81..9882eb7d1b4e 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1071,6 +1071,7 @@ def _schedule_prefills( ) ignored_seq_groups: List[SequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = [] + using_prompt_embeds: bool = False waiting_queue = self.waiting @@ -1138,6 +1139,15 @@ def _schedule_prefills( waiting_queue.popleft() continue + # We cannot mix sequence groups that use prompt embeds and + # those that do not. + if len(seq_groups) == 0: + using_prompt_embeds = seq_group.uses_prompt_embeds() + if using_prompt_embeds != seq_group.uses_prompt_embeds(): + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + lora_int_id = 0 if self.lora_enabled: lora_int_id = seq_group.lora_int_id diff --git a/vllm/sequence.py b/vllm/sequence.py index e3deb7b4a748..aeb876431ca5 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -916,6 +916,10 @@ def __repr__(self) -> str: f"sampling_params={self.sampling_params}, " f"num_seqs={len(self.seqs)})") + def uses_prompt_embeds(self) -> bool: + """Returns True if the sequence group uses input embeds.""" + return any(seq.data.prompt_embeds is not None for seq in self.seqs) + class SequenceGroupMetadataDelta( msgspec.Struct, From c8fcfe41079c808fe84001c0ef797a42a5723574 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 4 Apr 2025 09:58:55 -0500 Subject: [PATCH 28/96] fix style linelength Signed-off-by: Andrew Sansom --- vllm/worker/model_runner.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b364c17d9b9b..b11646901307 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -412,24 +412,24 @@ def __post_init__(self): self.lora_prompt_mapping = [] def __repr__(self) -> str: - return ( - f"InterDataForSeqGroup(" - f"request_id={self.request_id}, " - f"seq_ids={self.seq_ids}, " - f"is_prompt={self.is_prompt}, " - f"block_tables={self.block_tables}, " - f"computed_block_nums={self.computed_block_nums}, " - f"n_seqs={self.n_seqs}, " - f"input_tokens={self.input_tokens}, " - f"inputs_embeds.shape={getattr(self.inputs_embeds, 'shape', None)}, " - f"input_positions={self.input_positions}, " - f"token_types={self.token_types}, " - f"mrope_input_positions={self.mrope_input_positions}, " - f"seq_lens={self.seq_lens}, " - f"orig_seq_lens={self.orig_seq_lens}, " - f"query_lens={self.query_lens}, " - f"context_lens={self.context_lens}, " - f"multi_modal_kwargs={self.multi_modal_kwargs}") + return (f"InterDataForSeqGroup(" + f"request_id={self.request_id}, " + f"seq_ids={self.seq_ids}, " + f"is_prompt={self.is_prompt}, " + f"block_tables={self.block_tables}, " + f"computed_block_nums={self.computed_block_nums}, " + f"n_seqs={self.n_seqs}, " + f"input_tokens={self.input_tokens}, " + f"inputs_embeds.shape=" + f"{getattr(self.inputs_embeds, 'shape', None)}, " + f"input_positions={self.input_positions}, " + f"token_types={self.token_types}, " + f"mrope_input_positions={self.mrope_input_positions}, " + f"seq_lens={self.seq_lens}, " + f"orig_seq_lens={self.orig_seq_lens}, " + f"query_lens={self.query_lens}, " + f"context_lens={self.context_lens}, " + f"multi_modal_kwargs={self.multi_modal_kwargs}") def gen_inter_data_builder(self, num_seqs: int): return lambda: 
ModelInputForGPUBuilder.InterDataForSeqGroup(

From 1e359ae596768fc33b66aa894e61821da8ac0e8c Mon Sep 17 00:00:00 2001
From: Andrew Sansom
Date: Thu, 10 Apr 2025 21:05:47 -0500
Subject: [PATCH 29/96] propagate embeddings for sampled output tokens for decoding

Signed-off-by: Andrew Sansom
---
 tests/worker/test_model_runner.py           |  8 +++-
 vllm/engine/llm_engine.py                   |  6 ++-
 vllm/engine/output_processor/multi_step.py  |  6 ++-
 vllm/engine/output_processor/single_step.py |  3 +-
 vllm/model_executor/layers/sampler.py       |  7 ++-
 vllm/sequence.py                            | 51 +++++++++++++++++++--
 vllm/spec_decode/multi_step_worker.py       |  3 +-
 vllm/worker/model_runner.py                 | 34 +++++++++++---
 8 files changed, 99 insertions(+), 19 deletions(-)

diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index c065d552bad2..2a826b86aa9c 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -198,11 +198,13 @@ def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio,
                 prompt_embeds=torch.rand(context_len, 10),
             )
             input_embeds_len += context_len
+            output_embed = torch.rand(10)
         else:
             seq_data = SequenceData.from_seqs(range(context_len))
+            output_embed = None
         seq_data.update_num_computed_tokens(context_len)
         # Append one token ID since prefill is finished.
-        seq_data.append_token_id(1, 0)
+        seq_data.append_token_id(1, 0, output_embed)
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
@@ -399,9 +401,11 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio,
                 [],
                 prompt_embeds=torch.rand(context_len, 10),
             )
+            output_embed = torch.rand(10)
         else:
             seq_data = SequenceData.from_seqs(range(context_len))
-        seq_data.append_token_id(1, 0)
+            output_embed = None
+        seq_data.append_token_id(1, 0, output_embed)
         seq_data.update_num_computed_tokens(context_len)
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 038dcf9b9af6..924475be238a 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1290,11 +1290,13 @@ def _advance_to_next_step(
                 if self.scheduler_config.is_multi_step:
                     is_prefill_append = seq.data.get_num_uncomputed_tokens(
                     ) == 0
-                    seq.append_token_id(sample.output_token, sample.logprobs)
+                    seq.append_token_id(sample.output_token, sample.logprobs,
+                                        sample.output_embed)
                     if not is_prefill_append:
                         seq_group.update_num_computed_tokens(1)
                 else:
-                    seq.append_token_id(sample.output_token, sample.logprobs)
+                    seq.append_token_id(sample.output_token, sample.logprobs,
+                                        sample.output_embed)

     def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py
index 4c5d78a43df6..1a5971830837 100644
--- a/vllm/engine/output_processor/multi_step.py
+++ b/vllm/engine/output_processor/multi_step.py
@@ -167,6 +167,7 @@ def _process_seq_outputs(self, seq: Sequence,
                              sampling_params: SamplingParams) -> None:
         output_token_ids = [sample.output_token for sample in valid_samples]
         output_logprobs = [sample.logprobs for sample in valid_samples]
+        output_embeds = [sample.output_embed for sample in valid_samples]

         # Truncate to max_tokens if necessary.
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + @@ -190,11 +191,12 @@ def _process_seq_outputs(self, seq: Sequence, is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0 # Incrementally append tokens to the sequence, as if we had only one new # token. - for output_token_id, output_logprob in zip(output_token_ids, - output_logprobs): + for output_token_id, output_logprob, output_embed in zip( + output_token_ids, output_logprobs, output_embeds): seq.append_token_id( token_id=output_token_id, logprobs=output_logprob, + token_embed=output_embed, ) if is_prefill_sampled_token: diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 4d96791a1f8a..b5b51bb25a86 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -119,7 +119,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, sample = outputs.samples[0] seq = seq_group.first_seq if not is_async: - seq.append_token_id(sample.output_token, sample.logprobs) + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( seq, sampling_params) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 1ee1332ac45e..9368992b24fe 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -110,6 +110,11 @@ class SamplerOutput( # 'broadcasted' to all other PP ranks for next step. sampled_token_ids_cpu: Optional[torch.Tensor] = None + # On-device tensor containing the sampled token embeddings (embeddings + # corresponding to the sampled token ids). Used when prompt embeddings are + # specified in lieu of prompt token ids or text. + sampled_token_embeds: Optional[torch.Tensor] = None + # Spec decode metrics populated by workers. spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None @@ -183,7 +188,7 @@ def __init__(self): # Whether or not the SamplerOutput should have on-device tensors # containing the sampled token ids and probabilities. This is used by - # speculative decoding. + # speculative decoding and when prompt embeddings are specified. self.include_gpu_probs_tensor = False self.should_modify_greedy_probs_inplace = False diff --git a/vllm/sequence.py b/vllm/sequence.py index aeb876431ca5..fb6ac29393b6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -167,6 +167,7 @@ class SequenceData(msgspec.Struct, default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) _prompt_embeds: Optional[torch.Tensor] = None + _output_embeds: Optional[torch.Tensor] = None ### The below fields should not be passed as an argument ### _cumulative_logprob: float = 0.0 @@ -178,6 +179,7 @@ class SequenceData(msgspec.Struct, _num_cached_tokens: int = 0 _stage: SequenceStage = SequenceStage.PREFILL _cached_all_token_ids: list[int] = msgspec.field(default_factory=list) + _cached_all_token_embeds: Optional[torch.Tensor] = None # It is used to get delta input. It is reset when `get_delta_and_reset` # is called. @@ -237,6 +239,8 @@ def __post_init__(self) -> None: self._prompt_token_ids_tuple: tuple[int, ...] 
= tuple( self._prompt_token_ids) self._update_cached_all_tokens() + if self._prompt_embeds is not None: + self._update_cached_all_token_embeds() def _update_cached_all_tokens(self): assert isinstance(self._prompt_token_ids, array) @@ -244,6 +248,13 @@ def _update_cached_all_tokens(self): self._cached_all_token_ids: list[int] = list(self._prompt_token_ids + self._output_token_ids) + def _update_cached_all_token_embeds(self): + assert isinstance(self._prompt_embeds, torch.Tensor) + self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds + if self._output_embeds is not None: + self._cached_all_token_embeds = torch.cat( + (self._cached_all_token_embeds, self._output_embeds), dim=0) + @property def cumulative_logprob(self) -> float: return self._cumulative_logprob @@ -276,6 +287,16 @@ def output_token_ids(self, new_output_token_ids) self._update_cached_all_tokens() + @property + def output_token_embeds(self) -> tuple[int, ...]: + return tuple(self._output_token_ids) + + @output_token_embeds.setter + def output_token_embeds(self, + new_output_token_embeds: torch.Tensor) -> None: + self._output_token_ids = new_output_token_embeds + self._update_cached_all_token_embeds() + @property def output_token_ids_array(self) -> array: """Return the prompt token ids in array type. @@ -293,6 +314,7 @@ def prompt_embeds(self) -> Optional[torch.Tensor]: @prompt_embeds.setter def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None: self._prompt_embeds = prompt_embeds + self._update_cached_all_token_embeds() @property def mrope_position_delta(self) -> Optional[int]: @@ -302,11 +324,26 @@ def mrope_position_delta(self) -> Optional[int]: def mrope_position_delta(self, new_mrope_position_delta): self._mrope_position_delta = new_mrope_position_delta - def append_token_id(self, token_id: int, logprob: float) -> None: + def append_token_id(self, + token_id: int, + logprob: float, + token_embed: Optional[torch.Tensor] = None) -> None: self._output_token_ids.append(token_id) self._new_appended_tokens.append(token_id) self._cached_all_token_ids.append(token_id) self._cumulative_logprob += logprob + if token_embed is not None: + # Do not pass in with batch or sequence dimensions + assert token_embed.ndim == 1 + token_embed = token_embed.detach().cpu().unsqueeze(0) + if self._output_embeds is None: + self._output_embeds = token_embed + else: + self._output_embeds = torch.cat( + (self._output_embeds, token_embed), dim=0) + assert self._cached_all_token_embeds is not None + self._cached_all_token_embeds = torch.cat( + (self._cached_all_token_embeds, token_embed), dim=0) def get_len(self) -> int: return len(self._output_token_ids) + len(self._prompt_token_ids) @@ -320,6 +357,9 @@ def get_output_len(self) -> int: def get_token_ids(self) -> list[int]: return self._cached_all_token_ids + def get_token_embeddings(self) -> Optional[torch.Tensor]: + return self._cached_all_token_embeds + def get_prefix_token_ids( self, num_tokens: int ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]: @@ -573,11 +613,12 @@ def reset_state_for_recompute(self): """Reset the sequence states for recomputation.""" self.data.reset_state_for_recompute() - def append_token_id(self, token_id: int, logprobs: dict[int, - Logprob]) -> None: + def append_token_id(self, token_id: int, logprobs: dict[int, Logprob], + token_embed: Optional[torch.Tensor]) -> None: assert token_id in logprobs self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob) + self.data.append_token_id(token_id, 
logprobs[token_id].logprob, + token_embed) def get_len(self) -> int: return self.data.get_len() @@ -1077,10 +1118,12 @@ class SequenceOutput( parent_seq_id: int output_token: int logprobs: dict[int, Logprob] + output_embed: Optional[torch.Tensor] = None def __repr__(self) -> str: return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " f"output_token={self.output_token}, " + f"output_embed.shape={self.output_embed.shape}" f"logprobs={self.logprobs})") def __eq__(self, other: object) -> bool: diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index d8d54918fa98..324e226dde6d 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -278,7 +278,8 @@ def _append_new_tokens( else: count += 1 - seq.append_token_id(token_id, token_logprob.logprob) + seq.append_token_id(token_id, token_logprob.logprob, + seq_output.output_embed) seq.update_num_computed_tokens(1) @staticmethod diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b11646901307..57370dd3476e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -34,7 +34,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import supports_lora, supports_multimodal @@ -547,11 +547,14 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, prompt_embeds = None else: tokens = [0] * (seq_len - context_len) - prompt_embeds = seq_data.prompt_embeds[context_len:seq_len] - if len(prompt_embeds) == 0: - # Sometimes the prompt_embeds can be fully processed, so the - # seq_data.prompt_embeds[context_len:seq_len] can be empty. - prompt_embeds = None + prompt_embeds = seq_data.get_token_embeddings( + )[context_len:seq_len] + # # prompt_embeds = seq_data.prompt_embeds + # prompt_embeds = seq_data.prompt_embeds[context_len:seq_len] + # # if len(prompt_embeds) == 0: + # # # Sometimes the prompt_embeds can be fully processed, so the + # # # seq_data.prompt_embeds[context_len:seq_len] can be empty. + # # prompt_embeds = None token_types = seq_group_metadata.token_type_ids @@ -887,6 +890,7 @@ def build(self) -> ModelInputForGPU: dim=0).to( dtype=self.runner.model_config.dtype, device=self.runner.device) + assert len(inputs_embeds) == len(input_tokens) if not input_tokens and inputs_embeds is None: # This may happen when all prefill requests hit @@ -1885,6 +1889,12 @@ def execute_model( model_input.async_callback() # Sample the next token. 
+ assert isinstance(self.model.sampler, Sampler) + original_include_gpu_probs_tensor = \ + self.model.sampler.include_gpu_probs_tensor + if model_input.inputs_embeds is not None: + self.model.sampler.include_gpu_probs_tensor = True + output: SamplerOutput = self.model.sample( logits=logits, sampling_metadata=model_input.sampling_metadata, @@ -1906,6 +1916,18 @@ def execute_model( output.model_forward_time = (orig_model_forward_time + model_forward_time) + if model_input.inputs_embeds is not None: + self.model.sampler.include_gpu_probs_tensor = \ + original_include_gpu_probs_tensor + if output.sampled_token_ids is not None: + output.sampled_token_embeds = self.model.get_input_embeddings( + output.sampled_token_ids.squeeze(1)) + + for token_embed, sequence_group_output in zip( + output.sampled_token_embeds, output.outputs): + assert len(sequence_group_output.samples) == 1 + sequence_group_output.samples[0].output_embed = token_embed + if self.return_hidden_states: # we only need to pass hidden states of most recent token assert model_input.sampling_metadata is not None From 59fbe70213477c3c2921c28b96ed3ce65efd9958 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Thu, 10 Apr 2025 21:57:45 -0500 Subject: [PATCH 30/96] fix type check Signed-off-by: Andrew Sansom --- vllm/sequence.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index fb6ac29393b6..536d40a25fd7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1121,9 +1121,11 @@ class SequenceOutput( output_embed: Optional[torch.Tensor] = None def __repr__(self) -> str: + output_embed_shape = \ + self.output_embed.shape if self.output_embed is not None else None return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " f"output_token={self.output_token}, " - f"output_embed.shape={self.output_embed.shape}" + f"output_embed.shape={output_embed_shape}" f"logprobs={self.logprobs})") def __eq__(self, other: object) -> bool: From c152a3acf8dd4815fea745889db579c2810fdab5 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Thu, 10 Apr 2025 22:34:36 -0500 Subject: [PATCH 31/96] do not schedule decode sequence groups with batches containing both prompt embeds and not Signed-off-by: Andrew Sansom --- vllm/core/scheduler.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 9882eb7d1b4e..947fb43fe230 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1308,14 +1308,35 @@ def _schedule_default(self) -> SchedulerOutputs: if num_prefill_groups > 0: scheduled_seq_groups = prefills.seq_groups scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) + ignored_seq_groups_for_embeds = [] else: scheduled_seq_groups = running_scheduled.decode_seq_groups + if len(scheduled_seq_groups) > 0: + using_prompt_embeds = scheduled_seq_groups[ + 0].seq_group.uses_prompt_embeds() + ignored_seq_groups_for_embeds = [] + indices_ignored = [] + for i, seq_group in enumerate(scheduled_seq_groups): + if using_prompt_embeds !=\ + seq_group.seq_group.uses_prompt_embeds(): + ignored_seq_groups_for_embeds.append( + seq_group.seq_group) + indices_ignored.append(i) + if len(ignored_seq_groups_for_embeds) > 0: + scheduled_seq_groups = [ + group for i, group in enumerate(scheduled_seq_groups) + if i not in indices_ignored + ] + else: + ignored_seq_groups_for_embeds = [] + scheduled_seq_groups.extend(swapped_in.decode_seq_groups) blocks_to_copy = running_scheduled.blocks_to_copy 
blocks_to_copy.extend(swapped_in.blocks_to_copy) ignored_seq_groups = prefills.ignored_seq_groups + ignored_seq_groups.extend(ignored_seq_groups_for_embeds) ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) return SchedulerOutputs( From e7ab2a22f7c6226e0451bfe44c3aaedb762144e1 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Thu, 10 Apr 2025 22:43:06 -0500 Subject: [PATCH 32/96] fix type check Signed-off-by: Andrew Sansom --- vllm/core/scheduler.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 947fb43fe230..7e143f156a6e 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1305,6 +1305,7 @@ def _schedule_default(self) -> SchedulerOutputs: # Merge lists num_prefill_groups = len(prefills.seq_groups) + ignored_seq_groups_for_embeds: List[SequenceGroup] = [] if num_prefill_groups > 0: scheduled_seq_groups = prefills.seq_groups scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) @@ -1315,12 +1316,12 @@ def _schedule_default(self) -> SchedulerOutputs: using_prompt_embeds = scheduled_seq_groups[ 0].seq_group.uses_prompt_embeds() ignored_seq_groups_for_embeds = [] - indices_ignored = [] - for i, seq_group in enumerate(scheduled_seq_groups): + indices_ignored: List[int] = [] + for i, schedule_seq_group in enumerate(scheduled_seq_groups): if using_prompt_embeds !=\ - seq_group.seq_group.uses_prompt_embeds(): + schedule_seq_group.seq_group.uses_prompt_embeds(): ignored_seq_groups_for_embeds.append( - seq_group.seq_group) + schedule_seq_group.seq_group) indices_ignored.append(i) if len(ignored_seq_groups_for_embeds) > 0: scheduled_seq_groups = [ From 911adbe112ffb4b539da12f85495839979bd27de Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 11 Apr 2025 00:22:14 -0500 Subject: [PATCH 33/96] add default value to optional parameter Signed-off-by: Andrew Sansom --- vllm/sequence.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 536d40a25fd7..f86c24232012 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -613,8 +613,10 @@ def reset_state_for_recompute(self): """Reset the sequence states for recomputation.""" self.data.reset_state_for_recompute() - def append_token_id(self, token_id: int, logprobs: dict[int, Logprob], - token_embed: Optional[torch.Tensor]) -> None: + def append_token_id(self, + token_id: int, + logprobs: dict[int, Logprob], + token_embed: Optional[torch.Tensor] = None) -> None: assert token_id in logprobs self.output_logprobs.append(logprobs) self.data.append_token_id(token_id, logprobs[token_id].logprob, From 82d923d60c94b3d4bda3cf0533ee602e2897277c Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 14 Apr 2025 15:10:12 -0500 Subject: [PATCH 34/96] remove unused comments Signed-off-by: Andrew Sansom --- vllm/worker/model_runner.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c5033eaa666a..a1bad9e7b22c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -549,12 +549,6 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, tokens = [0] * (seq_len - context_len) prompt_embeds = seq_data.get_token_embeddings( )[context_len:seq_len] - # # prompt_embeds = seq_data.prompt_embeds - # prompt_embeds = seq_data.prompt_embeds[context_len:seq_len] - # # if len(prompt_embeds) == 0: - # # # Sometimes the prompt_embeds can be fully processed, so the - # # # 
seq_data.prompt_embeds[context_len:seq_len] can be empty. - # # prompt_embeds = None token_types = seq_group_metadata.token_type_ids From c95147928ec768488c25170779b663f235a4b4b6 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 14 Apr 2025 21:39:22 -0500 Subject: [PATCH 35/96] properly pass in placeholder token ids when testing prompt embeds Signed-off-by: Andrew Sansom --- tests/worker/test_model_runner.py | 43 ++++++++++++++++++------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 2a826b86aa9c..13149799ce9e 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -50,19 +50,19 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio, monkeypatch): seq_lens: list[int] = [] seq_group_metadata_list: list[SequenceGroupMetadata] = [] block_tables = {0: [1]} - input_embeds_len = 0 + expected_input_embeds_len = 0 for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) if random.random() < prompt_embeds_ratio: seq_data = SequenceData.from_seqs( - [], + prompt_token_ids=[0] * seq_len, prompt_embeds=torch.rand(seq_len, 10), ) - input_embeds_len += seq_len + expected_input_embeds_len += seq_len else: - seq_data = SequenceData.from_seqs(range(seq_len)) + seq_data = SequenceData.from_seqs(prompt_token_ids=range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -138,11 +138,11 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio, monkeypatch): assert len(input_tokens) == sum(seq_lens) assert len(input_positions) == sum(seq_lens) - if input_embeds_len == 0: + if expected_input_embeds_len == 0: torch.testing.assert_close(input_tokens, input_positions) assert input_embeds is None else: - assert len(input_embeds) == input_embeds_len + assert len(input_embeds) == expected_input_embeds_len sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, @@ -187,20 +187,21 @@ def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio, context_lens: list[int] = [] seq_group_metadata_list: list[SequenceGroupMetadata] = [] # Assume each seq group finishes prefill. - input_embeds_len = 0 + expected_input_embeds_len = 0 for i in range(batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) if random.random() < prompt_embeds_ratio: seq_data = SequenceData.from_seqs( - [], + prompt_token_ids=[0] * context_len, prompt_embeds=torch.rand(context_len, 10), ) - input_embeds_len += context_len + expected_input_embeds_len += context_len output_embed = torch.rand(10) else: - seq_data = SequenceData.from_seqs(range(context_len)) + seq_data = SequenceData.from_seqs( + prompt_token_ids=range(context_len)) output_embed = None seq_data.update_num_computed_tokens(context_len) # Append one token ID since prefill is finished. 
@@ -368,19 +369,20 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, block_tables = {0: [1]} prefill_batch_size = batch_size // 2 decode_batch_size = batch_size - prefill_batch_size - input_embeds_len = 0 + expected_input_embeds_len = 0 for i in range(prefill_batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) if random.random() < prompt_embeds_ratio: seq_data = SequenceData.from_seqs( - [], + prompt_token_ids=[0] * seq_len, prompt_embeds=torch.rand(seq_len, 10), ) - input_embeds_len += seq_len + expected_input_embeds_len += seq_len else: - seq_data = SequenceData.from_seqs(range(seq_len)) + seq_data = SequenceData.from_seqs( + prompt_token_ids=range(seq_len), ) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -398,13 +400,18 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, context_len = i % (model_runner.block_size - 1) + 1 if random.random() < prompt_embeds_ratio: seq_data = SequenceData.from_seqs( - [], + prompt_token_ids=[0] * context_len, prompt_embeds=torch.rand(context_len, 10), ) output_embed = torch.rand(10) + # This also iterates the expected input_embeds, because the model + # needs both the input and output embeddings passed into together + expected_input_embeds_len += 1 else: - seq_data = SequenceData.from_seqs(range(context_len)) + seq_data = SequenceData.from_seqs( + prompt_token_ids=range(context_len), ) output_embed = None + assert len(seq_data.prompt_token_ids) == context_len seq_data.append_token_id(1, 0, output_embed) seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( @@ -434,10 +441,10 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, assert attn_metadata.num_prefills == prefill_batch_size assert attn_metadata.num_decode_tokens == decode_batch_size assert attn_metadata.num_prefill_tokens == sum(seq_lens) - if input_embeds_len == 0: + if expected_input_embeds_len == 0: assert input_embeds is None else: - assert len(input_embeds) == input_embeds_len + assert len(input_embeds) == expected_input_embeds_len # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. 
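Taken together, the intended end-to-end flow for callers looks roughly like the sketch below. This is illustrative only: the model name is a placeholder, the embeddings are produced with the Hugging Face input-embedding layer (the same approach the language-model test later in this series uses), and the V0 engine is assumed because prompt embeddings are not supported on V1 at this point.

    import os

    os.environ["VLLM_USE_V1"] = "0"  # prompt embeddings are V0-only at this point

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from vllm import LLM, SamplingParams

    model_name = "facebook/opt-125m"  # placeholder; any decoder-only model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_model = AutoModelForCausalLM.from_pretrained(model_name)

    token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
    # (seq_len, hidden_size) embeddings from the HF input-embedding layer.
    prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

    llm = LLM(model=model_name)
    outputs = llm.generate({"prompt_embeds": prompt_embeds},
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)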
From 01e1a6ebff38bda28d59f15ba0d15b206487a21e Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 14 Apr 2025 21:43:55 -0500 Subject: [PATCH 36/96] do not test mixed token_ids/prompt_embeds batches in the model_runner Signed-off-by: Andrew Sansom --- tests/worker/test_model_runner.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 13149799ce9e..8976b1c606ef 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -import random - import pytest import torch @@ -34,9 +32,9 @@ def test_deepseek_mla_attn_backend_module(): @pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) -@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0)) -def test_prepare_prompt(batch_size, prompt_embeds_ratio, monkeypatch): - if prompt_embeds_ratio > 0.0: +@pytest.mark.parametrize("use_prompt_embeds", [True, False]) +def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch): + if use_prompt_embeds: # Prompt Embeddings is only currently supported on V0 monkeypatch.setenv("VLLM_USE_V1", "0") @@ -55,7 +53,7 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio, monkeypatch): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - if random.random() < prompt_embeds_ratio: + if use_prompt_embeds: seq_data = SequenceData.from_seqs( prompt_token_ids=[0] * seq_len, prompt_embeds=torch.rand(seq_len, 10), @@ -167,10 +165,9 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio, monkeypatch): @pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) -@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0)) -def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio, - monkeypatch): - if prompt_embeds_ratio > 0.0: +@pytest.mark.parametrize("use_prompt_embeds", [True, False]) +def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch): + if use_prompt_embeds: # Prompt Embeddings is only currently supported on V0 monkeypatch.setenv("VLLM_USE_V1", "0") @@ -192,7 +189,7 @@ def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio, # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) - if random.random() < prompt_embeds_ratio: + if use_prompt_embeds: seq_data = SequenceData.from_seqs( prompt_token_ids=[0] * context_len, prompt_embeds=torch.rand(context_len, 10), @@ -344,10 +341,10 @@ def distributed_init(): @pytest.mark.parametrize("batch_size", list(range(2, 128, 3))) @pytest.mark.parametrize("enforce_eager", [True, False]) -@pytest.mark.parametrize('prompt_embeds_ratio', [0.0, 0.5, 1.0]) -def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, +@pytest.mark.parametrize('use_prompt_embeds', [True, False]) +def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds, distributed_init, monkeypatch): - if prompt_embeds_ratio > 0.0: + if use_prompt_embeds: # Prompt Embeddings is only currently supported on V0 monkeypatch.setenv("VLLM_USE_V1", "0") @@ -374,7 +371,7 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - if random.random() < prompt_embeds_ratio: + if use_prompt_embeds: seq_data = SequenceData.from_seqs( prompt_token_ids=[0] * seq_len, 
prompt_embeds=torch.rand(seq_len, 10), @@ -398,7 +395,7 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 - if random.random() < prompt_embeds_ratio: + if use_prompt_embeds: seq_data = SequenceData.from_seqs( prompt_token_ids=[0] * context_len, prompt_embeds=torch.rand(context_len, 10), From 193ad5cbab2a386eb6abfcae00bfe1994a2a5749 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 14 Apr 2025 22:29:17 -0500 Subject: [PATCH 37/96] refactor cuda_prepare_decode test Signed-off-by: Andrew Sansom --- tests/worker/test_model_runner.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 8976b1c606ef..cf8c345b99d2 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -184,7 +184,6 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch): context_lens: list[int] = [] seq_group_metadata_list: list[SequenceGroupMetadata] = [] # Assume each seq group finishes prefill. - expected_input_embeds_len = 0 for i in range(batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 @@ -194,7 +193,6 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch): prompt_token_ids=[0] * context_len, prompt_embeds=torch.rand(context_len, 10), ) - expected_input_embeds_len += context_len output_embed = torch.rand(10) else: seq_data = SequenceData.from_seqs( @@ -271,8 +269,11 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch): assert len(input_tokens) == expected_bs assert len(input_positions) == expected_bs - torch.allclose(input_tokens, input_positions) - assert input_embeds is None + torch.testing.assert_close(input_tokens, input_positions) + if use_prompt_embeds: + assert len(input_embeds) == len(input_tokens) + else: + assert input_embeds is None # Verify Sampling expected_selected_token_indices = [] From 74bd9f46666a7e6c2f3141b6d312607073152f0d Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 15 Apr 2025 11:02:34 -0500 Subject: [PATCH 38/96] use correct expected input embeds length for prepare_decode_cuda_graph tests Signed-off-by: Andrew Sansom --- tests/worker/test_model_runner.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index cf8c345b99d2..4d260ec7b3b8 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -213,10 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch): model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, input_embeds, attn_metadata, slot_mapping = ( - model_input.input_tokens, model_input.input_positions, - model_input.inputs_embeds, model_input.attn_metadata, - model_input.attn_metadata.slot_mapping) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + input_embeds = model_input.inputs_embeds + attn_metadata = model_input.attn_metadata + slot_mapping = attn_metadata.slot_mapping + assert len(slot_mapping) == len(input_tokens) expected_bs = model_runner.vllm_config.pad_for_cudagraph( @@ -261,7 +263,7 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch): 
# block table's first index corresponds to each batch, meaning in # decoding it is each token. assert attn_metadata.block_tables.shape[0] == len(input_tokens) - # Block table's second dim correspondsd to each token's block number. + # Block table's second dim corresponds to each token's block number. # It is padded up to assert attn_metadata.block_tables.shape[1] == ( model_runner.get_max_block_per_batch()) @@ -269,9 +271,10 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch): assert len(input_tokens) == expected_bs assert len(input_positions) == expected_bs - torch.testing.assert_close(input_tokens, input_positions) if use_prompt_embeds: - assert len(input_embeds) == len(input_tokens) + expected_input_embeds_length = start_loc[-1] + assert len(input_embeds) == expected_input_embeds_length + assert expected_input_embeds_length <= expected_bs else: assert input_embeds is None From d949f1b0b3d91cf4eb3c7d549d49653fb7090362 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 15 Apr 2025 15:50:39 -0500 Subject: [PATCH 39/96] add scheduler test to ensure prompt embeds and prompt tokens are not co-mingled Signed-off-by: Andrew Sansom --- tests/core/test_scheduler.py | 73 +++++++++++++++++++++++++++++++++++- tests/core/utils.py | 7 +++- 2 files changed, 78 insertions(+), 2 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 8bd64923fe22..fd5e0fdab48f 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -5,13 +5,14 @@ from unittest.mock import MagicMock import pytest # noqa +import torch from torch import Use # noqa from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.lora.request import LoRARequest -from vllm.sequence import SequenceGroup +from vllm.sequence import SequenceGroup, SequenceStatus from .utils import (append_new_token, append_new_token_seq, append_new_token_seq_group, create_dummy_prompt, @@ -968,3 +969,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching( ), "A partial prefix of C (4 tokens) should be prefilled, with the " "remaining tokens fit into 3 token budget (4-1 from the seqA). It will " "then be rounded down to 2 tokens on block size, thus 6 tokens in total." + + +def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds(): + """ + Test that the scheduler does not schedule batches with prompt tokens and + prompt embeddings co-mingled. 
+ """ + block_size = 2 + max_seq_group = 3 + scheduler = initialize_scheduler( + block_size=block_size, + num_cpu_blocks=16, + num_gpu_blocks=16, + max_num_seqs=max_seq_group, + max_model_len=100, + enable_prefix_caching=True, + ) + + # the odd indexed inputs should be passed in via embeddings, + # evens via token_ids + seq_length = 7 + embedding_size = 5 + num_seqs = 11 + seq_tokens = [] + seq_embeds = [] + for i in range(num_seqs): + if i % 2: + seq_tokens.append(list(range(seq_length))) + seq_embeds.append(None) + else: + seq_tokens.append([0] * seq_length) + seq_embeds.append(torch.rand(embedding_size)) + + seq_and_seq_groups = [ + create_dummy_prompt(f"{i}", + prompt_tokens=seq_tokens[i], + prompt_embeds=seq_embeds[i], + block_size=block_size) + for i in range(len(seq_tokens)) + ] + + for _, seq_group in seq_and_seq_groups: + scheduler.add_seq_group(seq_group) + + while not all(seq.is_finished() for seq, _ in seq_and_seq_groups): + unfinished_seq_groups = [ + seq_group for _, seq_group in seq_and_seq_groups + if not seq_group.is_finished() + ] + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) > 0 + batch_is_prompt_embeds = out.scheduled_seq_groups[ + 0].seq_group.uses_prompt_embeds() + expected_scheduled_seq_groups = [ + seq_group for seq_group in unfinished_seq_groups + if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds + ] + + # We should have as many scheduled groups as possible, without mixing + assert len(out.scheduled_seq_groups) == min( + max_seq_group, len(expected_scheduled_seq_groups)) + assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() == + batch_is_prompt_embeds + for scheduled_seq_group in out.scheduled_seq_groups) + + # Finish the scheduled groups + for scheduled_seq_group in out.scheduled_seq_groups: + for seq in scheduled_seq_group.seq_group.seqs: + seq.status = SequenceStatus.FINISHED_STOPPED + scheduler.free_finished_seq_groups() diff --git a/tests/core/utils.py b/tests/core/utils.py index ea18b879a317..1caa074a9ef2 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -5,6 +5,8 @@ from collections.abc import Sequence as GenericSequence from typing import Any, Optional +import torch + from vllm import SamplingParams from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.inputs import EncoderDecoderInputs, token_inputs @@ -19,6 +21,7 @@ def create_dummy_prompt( block_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, prompt_tokens: Optional[list[int]] = None, + prompt_embeds: Optional[torch.Tensor] = None, min_tokens: int = 0, max_tokens: int = 16, ) -> tuple[Sequence, SequenceGroup]: @@ -33,7 +36,9 @@ def create_dummy_prompt( prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence( int(request_id), - inputs=token_inputs(prompt_tokens, prompt=prompt_str), + inputs=token_inputs(prompt_token_ids=prompt_tokens, + prompt=prompt_str, + prompt_embeds=prompt_embeds), block_size=block_size, ) seq_group = SequenceGroup( From 62bbc88150fc230bb3f42a3a220e849a5184fd7c Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Wed, 16 Apr 2025 09:46:46 -0500 Subject: [PATCH 40/96] support inputs_embeds in compiled mode Signed-off-by: Andrew Sansom --- vllm/worker/model_runner.py | 43 +++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a1bad9e7b22c..ffd21fbb192d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1086,7 
+1086,7 @@ def __init__( self.max_batchsize_to_capture = \ self.vllm_config.compilation_config.max_capture_size - self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ + self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [ {} for _ in range(self.parallel_config.pipeline_parallel_size) ] self.graph_memory_pool: Optional[Tuple[ @@ -1529,6 +1529,10 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: input_positions = torch.zeros(max_batch_size, dtype=torch.long, device=self.device) + inputs_embeds = torch.zeros( + (max_batch_size, self.model_config.get_hidden_size()), + dtype=self.model_config.dtype, + device=self.device) if self.model_config.uses_mrope: input_positions = torch.tile(input_positions, (3, 1)).cuda(device=self.device) @@ -1568,13 +1572,15 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self.parallel_config.pipeline_parallel_size): # Only rank 0 should print progress bar during capture cudagraph_capture_sizes = (tqdm( - self.vllm_config.compilation_config. - cudagraph_capture_sizes, + list( + itertools.product( + self.vllm_config.compilation_config. + cudagraph_capture_sizes, [True, False])), desc="Capturing CUDA graph shapes", ) if get_tensor_model_parallel_rank() == 0 else self.vllm_config.compilation_config. cudagraph_capture_sizes) - for batch_size in cudagraph_capture_sizes: + for batch_size, use_inputs_embeds in cudagraph_capture_sizes: attn_metadata = ( self.attn_state.graph_capture_get_metadata_for_batch( batch_size, @@ -1605,6 +1611,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: capture_inputs = { "input_ids": input_tokens[:batch_size], + "inputs_embeds": + inputs_embeds[:batch_size] + if use_inputs_embeds else None, "positions": input_positions[..., :batch_size], "intermediate_inputs": @@ -1641,8 +1650,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: virtual_engine): graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[virtual_engine][batch_size] = ( - graph_runner) + self.graph_runners[virtual_engine][( + batch_size, use_inputs_embeds)] = (graph_runner) if self.lora_config: self._remove_dummy_loras() @@ -1774,8 +1783,9 @@ def execute_model( if prefill_meta is None and decode_meta.use_cuda_graph: assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[virtual_engine][ - graph_batch_size] + use_inputs_embeds = model_input.inputs_embeds is not None + model_executable = self.graph_runners[virtual_engine][( + graph_batch_size, use_inputs_embeds)] if previous_hidden_states is not None: previous_hidden_states = torch.cat([ previous_hidden_states, @@ -1826,9 +1836,8 @@ def execute_model( self.vllm_config, virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, - **{ - "inputs_embeds": model_input.inputs_embeds, - } if model_input.inputs_embeds is not None else {}, + inputs_embeds=model_input.inputs_embeds + if model_input.inputs_embeds is not None else None, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, @@ -2015,6 +2024,7 @@ def graph(self): def capture( self, input_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor], positions: torch.Tensor, intermediate_inputs: Optional[IntermediateTensors], kv_caches: List[torch.Tensor], @@ -2031,6 +2041,7 @@ def capture( for _ in 
range(_NUM_WARMUP_ITERS): self.model( input_ids=input_ids, + inputs_embeds=inputs_embeds, positions=positions, intermediate_tensors=intermediate_inputs, **kwargs, @@ -2043,6 +2054,9 @@ def capture( with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): output_hidden_or_intermediate_states = self.model( input_ids=input_ids, + **({ + "inputs_embeds": inputs_embeds, + } if inputs_embeds is not None else {}), positions=positions, intermediate_tensors=intermediate_inputs, **kwargs, @@ -2070,6 +2084,9 @@ def capture( self.input_buffers = { "input_ids": input_ids, + **({ + "inputs_embeds": inputs_embeds, + } if inputs_embeds is not None else {}), "positions": positions, "kv_caches": @@ -2090,6 +2107,7 @@ def capture( def forward( self, input_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor], positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors], **kwargs, @@ -2104,6 +2122,9 @@ def forward( # so the shape is not padded, we need to copy partial only self.input_buffers["positions"][:positions.shape[0]].copy_( positions, non_blocking=True) + if inputs_embeds is not None: + self.input_buffers["inputs_embeds"][:inputs_embeds.shape[0]].copy_( + inputs_embeds, non_blocking=True) if self.backend_name != "NO_ATTENTION": self.input_buffers["slot_mapping"].copy_( From 1d1ae4ba34c9b5d65e12b07405209ceccc168b34 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Wed, 16 Apr 2025 09:48:44 -0500 Subject: [PATCH 41/96] fix typing in test Signed-off-by: Andrew Sansom --- tests/core/test_scheduler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index fd5e0fdab48f..a5ba16898d89 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -2,6 +2,7 @@ import time from collections import deque +from typing import Optional from unittest.mock import MagicMock import pytest # noqa @@ -992,8 +993,8 @@ def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds(): seq_length = 7 embedding_size = 5 num_seqs = 11 - seq_tokens = [] - seq_embeds = [] + seq_tokens: list[list[int]] = [] + seq_embeds: list[Optional[torch.Tensor]] = [] for i in range(num_seqs): if i % 2: seq_tokens.append(list(range(seq_length))) From 1914676b01477a75d32c37ead6c6eff1b9d109f9 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Wed, 16 Apr 2025 10:28:28 -0500 Subject: [PATCH 42/96] use corrector operator precedence for handling empty strings Signed-off-by: Andrew Sansom --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 34765eaf1308..ece11376ce8e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -802,7 +802,7 @@ def generate( output_str = sample.text output_ids = list(sample.token_ids) req_sample_output_ids.append(prompt_ids + output_ids) - req_sample_output_strs.append(prompt_str or "" + output_str) + req_sample_output_strs.append((prompt_str or "") + output_str) outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs From 70198f6253474e6a77a49479a0742a36fc073da4 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Wed, 16 Apr 2025 10:59:04 -0500 Subject: [PATCH 43/96] only test decoder models with input embeds in v0 backend Signed-off-by: Andrew Sansom --- .../decoder_only/language/test_models.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 
e7fdbb26b546..63eb579ec72d 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -3,6 +3,7 @@ Run `pytest tests/models/test_models.py`. """ +import os import pytest import torch @@ -126,8 +127,9 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs( - prompt_embeds, max_tokens, num_logprobs) + if os.getenv("VLLM_USE_V1") == "0": + vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs( + prompt_embeds, max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=hf_outputs, @@ -135,12 +137,13 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, name_0="hf", name_1="vllm", ) - check_logprobs_close( - outputs_0_lst=vllm_outputs, - outputs_1_lst=vllm_outputs_from_embeds, - name_0="vllm", - name_1="vllm_from_embeds", - ) + if os.getenv("VLLM_USE_V1") == "0": + check_logprobs_close( + outputs_0_lst=vllm_outputs, + outputs_1_lst=vllm_outputs_from_embeds, + name_0="vllm", + name_1="vllm_from_embeds", + ) if use_rocm_aiter: # this is to ensure that vllm engine From 5595b4520b9ba4343b2e8adc0687e5230c2d4d1b Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 18 Apr 2025 15:25:59 -0500 Subject: [PATCH 44/96] adjust type hints for modelinputforgpubuilder.build Signed-off-by: Andrew Sansom --- vllm/worker/model_runner.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ffd21fbb192d..a2658484d57f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -861,25 +861,26 @@ def build(self) -> ModelInputForGPU: create on-device tensors. """ # Combine and flatten intermediate data. 
- input_tokens = [] - inputs_embeds_: list[torch.Tensor] = [] - token_types = [] + input_tokens = list[int]() + inputs_embeds_lst = list[torch.Tensor]() + token_types = list[int]() for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: input_tokens.extend(cur_input_tokens) for cur_token_types in inter_data.token_types: token_types.extend(cur_token_types) if inter_data.inputs_embeds is not None: - inputs_embeds_.append( + inputs_embeds_lst.append( inter_data.inputs_embeds.to( dtype=self.runner.model_config.dtype, device=self.runner.device)) inputs_embeds: Optional[torch.Tensor] - if len(inputs_embeds_) == 0: + if len(inputs_embeds_lst) == 0: inputs_embeds = None else: inputs_embeds = torch.cat([ - x.squeeze(dim=0) if x.dim() == 3 else x for x in inputs_embeds_ + x.squeeze(dim=0) if x.dim() == 3 else x + for x in inputs_embeds_lst ], dim=0).to( dtype=self.runner.model_config.dtype, From 3343d3e04da2e66f32e93690f619302e00013c49 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 18 Apr 2025 15:27:53 -0500 Subject: [PATCH 45/96] simplify conditional logic Signed-off-by: Andrew Sansom --- vllm/worker/model_runner.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a2658484d57f..1830defd3f54 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -286,10 +286,7 @@ def __init__( for seq_id in range(len(self.seq_ids)): self.input_tokens[seq_id].clear() - if inputs_embeds is not None: - self.inputs_embeds = inputs_embeds - else: - self.inputs_embeds = None + self.inputs_embeds = inputs_embeds if input_positions: self.input_positions = input_positions From 5010ea02df5a41d75baee701c6ee1f2a131ec975 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 18 Apr 2025 15:39:08 -0500 Subject: [PATCH 46/96] simplify compilation conditional logic Signed-off-by: Andrew Sansom --- vllm/worker/model_runner.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1830defd3f54..45ff49fcc3e3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1568,17 +1568,22 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: # memory usage of CUDA graph. for virtual_engine in range( self.parallel_config.pipeline_parallel_size): + # We need to not only iterate over batch sizes, but also whether + # to use inputs_embeds or not, hence we use the cartesian + # product. + cudagraph_capture_sizes = self.vllm_config.compilation_config\ + .cudagraph_capture_sizes + cudagraph_inputs_embeds = (True, False) + compilation_cases = itertools.product( + cudagraph_capture_sizes, + cudagraph_inputs_embeds, + ) # Only rank 0 should print progress bar during capture - cudagraph_capture_sizes = (tqdm( - list( - itertools.product( - self.vllm_config.compilation_config. - cudagraph_capture_sizes, [True, False])), - desc="Capturing CUDA graph shapes", - ) if get_tensor_model_parallel_rank() == 0 else - self.vllm_config.compilation_config. 
- cudagraph_capture_sizes) - for batch_size, use_inputs_embeds in cudagraph_capture_sizes: + if get_tensor_model_parallel_rank() == 0: + compilation_cases = tqdm( + list(compilation_cases), + desc="Capturing CUDA graph shapes") + for batch_size, use_inputs_embeds in compilation_cases: attn_metadata = ( self.attn_state.graph_capture_get_metadata_for_batch( batch_size, From 2075e538f351faf66e60f24c01a8b7bbc3189333 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 18 Apr 2025 15:46:17 -0500 Subject: [PATCH 47/96] refactor decoder only language model tests to reduce number of times an environment variable is referenced Signed-off-by: Andrew Sansom --- tests/models/decoder_only/language/test_models.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 63eb579ec72d..d54c578e5acf 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -114,20 +114,21 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) - prompt_embeds = [] + prompt_embeds = [] if os.getenv("VLLM_USE_V1") == "0" else None prompt_token_ids = [] for prompt in example_prompts: token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids.to( hf_model.model.device) prompt_token_ids.append(token_ids) - prompt_embeds.append( - hf_model.model.get_input_embeddings()(token_ids).squeeze(0)) + if prompt_embeds is not None: + prompt_embeds.append(hf_model.model.get_input_embeddings()( + token_ids).squeeze(0)) with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - if os.getenv("VLLM_USE_V1") == "0": + if prompt_embeds is not None: vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs( prompt_embeds, max_tokens, num_logprobs) @@ -137,7 +138,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, name_0="hf", name_1="vllm", ) - if os.getenv("VLLM_USE_V1") == "0": + if prompt_embeds is not None: check_logprobs_close( outputs_0_lst=vllm_outputs, outputs_1_lst=vllm_outputs_from_embeds, From 9a4fb3c085d7650d06e093ebceae60acd4273b4d Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 18 Apr 2025 15:50:59 -0500 Subject: [PATCH 48/96] break up multiple assignments for readability Signed-off-by: Andrew Sansom --- tests/worker/test_model_runner.py | 36 +++++++++++++++---------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 4d260ec7b3b8..9432e3db6e6c 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -307,25 +307,24 @@ def test_empty_seq_group(): seq_group_metadata_list: list[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, attn_metadata = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - ) + + input_tokens = model_input.input_tokens, + input_positions = model_input.input_positions, + attn_metadata = model_input.attn_metadata, + assert input_tokens is None assert input_positions is None assert attn_metadata is None model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - (input_tokens, input_positions, 
input_embeds, attn_metadata, - return_seq_lens) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.inputs_embeds, - model_input.attn_metadata, - model_input.seq_lens, - ) + + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + input_embeds = model_input.inputs_embeds + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + assert input_tokens is None assert input_positions is None assert input_embeds is None @@ -427,12 +426,11 @@ def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds, decode_metadata_list.append(seq_group_metadata) model_input = model_runner.prepare_model_input(seq_group_metadata_list) - (input_tokens, input_positions, input_embeds, attn_metadata) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.inputs_embeds, - model_input.attn_metadata, - ) + + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + input_embeds = model_input.inputs_embeds + attn_metadata = model_input.attn_metadata prefill_meta_actual = attn_metadata.prefill_metadata decode_meta_actual = attn_metadata.decode_metadata From 8ad40915da96028450b3a3e16febaa30a7608a64 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 18 Apr 2025 15:57:58 -0500 Subject: [PATCH 49/96] update type hints in scheduler Signed-off-by: Andrew Sansom --- vllm/core/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 7e143f156a6e..f8ebb820518e 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1305,7 +1305,7 @@ def _schedule_default(self) -> SchedulerOutputs: # Merge lists num_prefill_groups = len(prefills.seq_groups) - ignored_seq_groups_for_embeds: List[SequenceGroup] = [] + ignored_seq_groups_for_embeds = list[SequenceGroup]() if num_prefill_groups > 0: scheduled_seq_groups = prefills.seq_groups scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) @@ -1316,7 +1316,7 @@ def _schedule_default(self) -> SchedulerOutputs: using_prompt_embeds = scheduled_seq_groups[ 0].seq_group.uses_prompt_embeds() ignored_seq_groups_for_embeds = [] - indices_ignored: List[int] = [] + indices_ignored = list[int]() for i, schedule_seq_group in enumerate(scheduled_seq_groups): if using_prompt_embeds !=\ schedule_seq_group.seq_group.uses_prompt_embeds(): From 9055daf7921d704deca0bffa29c9a9b5e312dff4 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 18 Apr 2025 15:59:11 -0500 Subject: [PATCH 50/96] clear existing lists instead of instantiating new ones Signed-off-by: Andrew Sansom --- vllm/core/scheduler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f8ebb820518e..f833470f3ef7 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1309,13 +1309,13 @@ def _schedule_default(self) -> SchedulerOutputs: if num_prefill_groups > 0: scheduled_seq_groups = prefills.seq_groups scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) - ignored_seq_groups_for_embeds = [] + ignored_seq_groups_for_embeds.clear() else: scheduled_seq_groups = running_scheduled.decode_seq_groups if len(scheduled_seq_groups) > 0: using_prompt_embeds = scheduled_seq_groups[ 0].seq_group.uses_prompt_embeds() - ignored_seq_groups_for_embeds = [] + ignored_seq_groups_for_embeds.clear() indices_ignored = list[int]() for i, schedule_seq_group in enumerate(scheduled_seq_groups): if using_prompt_embeds !=\ 
@@ -1329,7 +1329,7 @@ def _schedule_default(self) -> SchedulerOutputs: if i not in indices_ignored ] else: - ignored_seq_groups_for_embeds = [] + ignored_seq_groups_for_embeds.clear() scheduled_seq_groups.extend(swapped_in.decode_seq_groups) From 9a57acaf55e11222839bcb2e172a5f1119615554 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 18 Apr 2025 16:30:47 -0500 Subject: [PATCH 51/96] preprocess tensors to handle batched/misshaped prompt embeds to avoid handling shape mismatches in the engine and model runner Signed-off-by: Andrew Sansom --- vllm/engine/llm_engine.py | 6 ++---- vllm/inputs/preprocess.py | 14 +++++++++++++- vllm/worker/model_runner.py | 10 +++------- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 81084f0276cb..1142b0d92d2f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -776,10 +776,8 @@ def add_request( if (isinstance(prompt, dict) and prompt.get("prompt_embeds", None) is not None and not prompt.get("prompt_token_ids", None)): - # We use the -2 dimension (instead of 0) in case a batched input - # of batch size 1 is passed in. - prompt["prompt_token_ids"] = [0 - ] * prompt["prompt_embeds"].shape[-2] + seq_len = prompt["prompt_embeds"].shape[0] + prompt["prompt_token_ids"] = [0] * seq_len if self.tokenizer is not None: self._validate_token_prompt( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index f77880f97b0d..0787e73d76f6 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -342,6 +342,18 @@ def _prompt_to_llm_inputs( token_type_ids = tokens_content.get("token_type_ids") multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + prompt_embeds = tokens_content.get("prompt_embeds") + + # prompt_embeds must be (seq_len, hidden_size), but if the user + # passes in a batch of size 1, i.e. (1, seq_len, hidden_size), + # we can unambiguously process the intent by squeezing the batch + # dimension. 
+ if prompt_embeds.ndim() == 3 and prompt_embeds.shape[0] == 1: + prompt_embeds = prompt_embeds.squeeze(dim=0) + + if prompt_embeds.ndim() != 2: + raise ValueError( + "prompt_embeds must be of shape (seq_len, hidden_size).") if multi_modal_data is not None and self._can_process_multimodal(): return self._process_multimodal( @@ -354,7 +366,7 @@ def _prompt_to_llm_inputs( return token_inputs( prompt_token_ids=prompt_token_ids, - prompt_embeds=tokens_content.get("prompt_embeds"), + prompt_embeds=prompt_embeds, token_type_ids=token_type_ids, multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 45ff49fcc3e3..d0d3bd39b4a6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -875,13 +875,9 @@ def build(self) -> ModelInputForGPU: if len(inputs_embeds_lst) == 0: inputs_embeds = None else: - inputs_embeds = torch.cat([ - x.squeeze(dim=0) if x.dim() == 3 else x - for x in inputs_embeds_lst - ], - dim=0).to( - dtype=self.runner.model_config.dtype, - device=self.runner.device) + inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to( + dtype=self.runner.model_config.dtype, + device=self.runner.device) assert len(inputs_embeds) == len(input_tokens) if not input_tokens and inputs_embeds is None: From bbfb0f0ee16b63bb273af7676213ff2f5b2621c8 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 18 Apr 2025 17:02:17 -0500 Subject: [PATCH 52/96] use seperate Embedsprompt class for preprocessing inputs embeddings Signed-off-by: Andrew Sansom --- vllm/inputs/data.py | 19 +++++++----- vllm/inputs/parse.py | 15 ++++++--- vllm/inputs/preprocess.py | 64 +++++++++++++++++++++++++++++---------- 3 files changed, 71 insertions(+), 27 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 8be04b766a15..4286518637b2 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -20,9 +20,6 @@ class TextPrompt(TypedDict): prompt: str """The input text to be tokenized before passing to the model.""" - prompt_embeds: NotRequired[torch.Tensor] - """The embeddings of the prompt, if available.""" - multi_modal_data: NotRequired["MultiModalDataDict"] """ Optional multi-modal data to pass to the model, @@ -44,9 +41,6 @@ class TokensPrompt(TypedDict): prompt_token_ids: list[int] """A list of token IDs to pass to the model.""" - prompt_embeds: NotRequired[torch.Tensor] - """The embeddings of the prompt, if available.""" - token_type_ids: NotRequired[list[int]] """A list of token type IDs to pass to the cross encoder model.""" @@ -65,7 +59,18 @@ class TokensPrompt(TypedDict): """ -SingletonPrompt = Union[str, TextPrompt, TokensPrompt] +class EmbedsPrompt(TypedDict): + """Schema for a prompt provided via token embeddings.""" + + prompt_embeds: torch.Tensor + """The embeddings of the prompt.""" + prompt_token_ids: NotRequired[list[int]] + """A list of token IDs to pass to the model.""" + prompt: NotRequired[str] + """The input text to be tokenized before passing to the model.""" + + +SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] """ Set of possible schemas for a single prompt: diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index c426f36d9041..1eff27aeadc2 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -6,8 +6,9 @@ from vllm.utils import is_list_of -from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, - SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt) +from .data import (EmbedsPrompt, 
ExplicitEncoderDecoderPrompt, ProcessorInputs, + PromptType, SingletonInputs, SingletonPrompt, TextPrompt, + TokensPrompt) class ParsedText(TypedDict): @@ -84,9 +85,15 @@ class ParsedTokensPrompt(TypedDict): content: TokensPrompt +class ParsedEmbedsPrompt(TypedDict): + type: Literal['embeds'] + content: EmbedsPrompt + + def parse_singleton_prompt( prompt: SingletonPrompt, -) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: +) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, + ParsedEmbedsPrompt]: if isinstance(prompt, str): return ParsedStrPrompt(type="str", content=prompt) elif isinstance(prompt, dict): @@ -97,7 +104,7 @@ def parse_singleton_prompt( return ParsedTextPrompt(type="text", content=prompt) elif "prompt_embeds" in prompt: - return ParsedTokensPrompt(type="tokens", content=prompt) + return ParsedEmbedsPrompt(type="embeds", content=prompt) raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 0787e73d76f6..01457109df9a 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -342,18 +342,6 @@ def _prompt_to_llm_inputs( token_type_ids = tokens_content.get("token_type_ids") multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") - prompt_embeds = tokens_content.get("prompt_embeds") - - # prompt_embeds must be (seq_len, hidden_size), but if the user - # passes in a batch of size 1, i.e. (1, seq_len, hidden_size), - # we can unambiguously process the intent by squeezing the batch - # dimension. - if prompt_embeds.ndim() == 3 and prompt_embeds.shape[0] == 1: - prompt_embeds = prompt_embeds.squeeze(dim=0) - - if prompt_embeds.ndim() != 2: - raise ValueError( - "prompt_embeds must be of shape (seq_len, hidden_size).") if multi_modal_data is not None and self._can_process_multimodal(): return self._process_multimodal( @@ -366,7 +354,6 @@ def _prompt_to_llm_inputs( return token_inputs( prompt_token_ids=prompt_token_ids, - prompt_embeds=prompt_embeds, token_type_ids=token_type_ids, multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, @@ -396,11 +383,34 @@ def _prompt_to_llm_inputs( return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, - prompt_embeds=text_content.get("prompt_embeds"), multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) + if parsed["type"] == "embeds": + prompt_embeds_content = parsed["content"] + + prompt_embeds = prompt_embeds_content["prompt_embeds"] + prompt = prompt_embeds_content.get("prompt") + prompt_token_ids = prompt_embeds_content.get("prompt_token_ids") + + # prompt_embeds must be (seq_len, hidden_size), but if the user + # passes in a batch of size 1, i.e. (1, seq_len, hidden_size), + # we can unambiguously process the intent by squeezing the batch + # dimension. 
+ if prompt_embeds.ndim() == 3 and prompt_embeds.shape[0] == 1: + prompt_embeds = prompt_embeds.squeeze(dim=0) + + if prompt_embeds.ndim() != 2: + raise ValueError( + "prompt_embeds must be of shape (seq_len, hidden_size).") + + return token_inputs( + prompt_token_ids=prompt_token_ids, + prompt=prompt, + prompt_embeds=prompt_embeds, + ) + assert_never(parsed) async def _prompt_to_llm_inputs_async( @@ -442,7 +452,6 @@ async def _prompt_to_llm_inputs_async( return token_inputs( prompt_token_ids=prompt_token_ids, - prompt_embeds=tokens_content.get("prompt_embeds"), multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) @@ -471,11 +480,34 @@ async def _prompt_to_llm_inputs_async( return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, - prompt_embeds=tokens_content.get("prompt_embeds"), multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) + if parsed["type"] == "embeds": + prompt_embeds_content = parsed["content"] + + prompt_embeds = prompt_embeds_content["prompt_embeds"] + prompt = prompt_embeds_content.get("prompt") + prompt_token_ids = prompt_embeds_content.get("prompt_token_ids") + + # prompt_embeds must be (seq_len, hidden_size), but if the user + # passes in a batch of size 1, i.e. (1, seq_len, hidden_size), + # we can unambiguously process the intent by squeezing the batch + # dimension. + if prompt_embeds.ndim() == 3 and prompt_embeds.shape[0] == 1: + prompt_embeds = prompt_embeds.squeeze(dim=0) + + if prompt_embeds.ndim() != 2: + raise ValueError( + "prompt_embeds must be of shape (seq_len, hidden_size).") + + return token_inputs( + prompt_token_ids=prompt_token_ids, + prompt=prompt, + prompt_embeds=prompt_embeds, + ) + assert_never(parsed) def _build_enc_dec_llm_inputs( From 933e56724526149a49a0e974ebacbc9e2e5b883c Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 18 Apr 2025 17:13:48 -0500 Subject: [PATCH 53/96] fix typing Signed-off-by: Andrew Sansom --- vllm/inputs/parse.py | 7 +++++-- vllm/inputs/preprocess.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 1eff27aeadc2..99585fb99d34 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -106,11 +106,14 @@ def parse_singleton_prompt( elif "prompt_embeds" in prompt: return ParsedEmbedsPrompt(type="embeds", content=prompt) - raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") + raise TypeError( + "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt") def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: - return isinstance(prompt, dict) and "prompt_token_ids" in prompt + return isinstance( + prompt, dict + ) and "prompt_token_ids" in prompt and "prompt_embeds" not in prompt def is_explicit_encoder_decoder_prompt( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 01457109df9a..dc04ddd46566 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -391,7 +391,7 @@ def _prompt_to_llm_inputs( prompt_embeds_content = parsed["content"] prompt_embeds = prompt_embeds_content["prompt_embeds"] - prompt = prompt_embeds_content.get("prompt") + prompt_text = prompt_embeds_content.get("prompt") prompt_token_ids = prompt_embeds_content.get("prompt_token_ids") # prompt_embeds must be (seq_len, hidden_size), but if the user @@ -407,7 +407,7 @@ def _prompt_to_llm_inputs( return token_inputs( prompt_token_ids=prompt_token_ids, - prompt=prompt, + prompt=prompt_text, prompt_embeds=prompt_embeds, ) From 
4e0d12f62b86e2c243c820cc89c35e929bda890c Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 18 Apr 2025 21:07:52 -0500 Subject: [PATCH 54/96] fix type errors Signed-off-by: Andrew Sansom --- vllm/inputs/data.py | 5 ++--- vllm/inputs/parse.py | 15 ++++++++------- vllm/inputs/preprocess.py | 21 +++++++++++---------- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 4286518637b2..81cfda5ecbae 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -293,10 +293,9 @@ def token_type_ids(self) -> list[int]: def prompt_embeds(self) -> Optional[torch.Tensor]: inputs = self.inputs - if inputs["type"] == "token" or inputs["type"] == "multimodal": + if inputs["type"] == "embeds": return inputs.get("prompt_embeds") - - assert_never(inputs) # type: ignore[arg-type] + return None @cached_property def multi_modal_data(self) -> "MultiModalDataDict": diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 99585fb99d34..acec86361a51 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -97,15 +97,16 @@ def parse_singleton_prompt( if isinstance(prompt, str): return ParsedStrPrompt(type="str", content=prompt) elif isinstance(prompt, dict): - if "prompt_token_ids" in prompt: - return ParsedTokensPrompt(type="tokens", - content=prompt) # type: ignore + # Type ignores are because mypy does not correctly infer the TypedDicts + # Pyright does succeed. + if "prompt_embeds" in prompt: + return ParsedEmbedsPrompt( + type="embeds", content=prompt) # type: ignore[typeddict-item] + elif "prompt_token_ids" in prompt: + return ParsedTokensPrompt( + type="tokens", content=prompt) # type: ignore[typeddict-item] elif "prompt" in prompt: return ParsedTextPrompt(type="text", content=prompt) - - elif "prompt_embeds" in prompt: - return ParsedEmbedsPrompt(type="embeds", content=prompt) - raise TypeError( "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt") diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index dc04ddd46566..636dd9fd9866 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -391,23 +391,24 @@ def _prompt_to_llm_inputs( prompt_embeds_content = parsed["content"] prompt_embeds = prompt_embeds_content["prompt_embeds"] - prompt_text = prompt_embeds_content.get("prompt") - prompt_token_ids = prompt_embeds_content.get("prompt_token_ids") # prompt_embeds must be (seq_len, hidden_size), but if the user # passes in a batch of size 1, i.e. (1, seq_len, hidden_size), # we can unambiguously process the intent by squeezing the batch # dimension. - if prompt_embeds.ndim() == 3 and prompt_embeds.shape[0] == 1: + if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1: prompt_embeds = prompt_embeds.squeeze(dim=0) - if prompt_embeds.ndim() != 2: + if prompt_embeds.ndim != 2: raise ValueError( "prompt_embeds must be of shape (seq_len, hidden_size).") + prompt_token_ids = prompt_embeds_content.get( + "prompt_token_ids", [0] * len(prompt_embeds)) + return token_inputs( prompt_token_ids=prompt_token_ids, - prompt=prompt_text, + prompt=prompt_embeds_content.get("prompt"), prompt_embeds=prompt_embeds, ) @@ -488,23 +489,23 @@ async def _prompt_to_llm_inputs_async( prompt_embeds_content = parsed["content"] prompt_embeds = prompt_embeds_content["prompt_embeds"] - prompt = prompt_embeds_content.get("prompt") - prompt_token_ids = prompt_embeds_content.get("prompt_token_ids") # prompt_embeds must be (seq_len, hidden_size), but if the user # passes in a batch of size 1, i.e. 
(1, seq_len, hidden_size), # we can unambiguously process the intent by squeezing the batch # dimension. - if prompt_embeds.ndim() == 3 and prompt_embeds.shape[0] == 1: + if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1: prompt_embeds = prompt_embeds.squeeze(dim=0) - if prompt_embeds.ndim() != 2: + if prompt_embeds.ndim != 2: raise ValueError( "prompt_embeds must be of shape (seq_len, hidden_size).") + prompt_token_ids = prompt_embeds_content.get( + "prompt_token_ids", [0] * len(prompt_embeds)) return token_inputs( prompt_token_ids=prompt_token_ids, - prompt=prompt, + prompt=prompt_embeds_content.get("prompt"), prompt_embeds=prompt_embeds, ) From 9e6909e6d470e8d985650cb61b5238b054fc3782 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 18 Apr 2025 21:52:28 -0500 Subject: [PATCH 55/96] fix mistaken type change Signed-off-by: Andrew Sansom --- vllm/inputs/data.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 81cfda5ecbae..4286518637b2 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -293,9 +293,10 @@ def token_type_ids(self) -> list[int]: def prompt_embeds(self) -> Optional[torch.Tensor]: inputs = self.inputs - if inputs["type"] == "embeds": + if inputs["type"] == "token" or inputs["type"] == "multimodal": return inputs.get("prompt_embeds") - return None + + assert_never(inputs) # type: ignore[arg-type] @cached_property def multi_modal_data(self) -> "MultiModalDataDict": From 90b950adaae6229e03a79b4bbb6917a74e460d4a Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 18 Apr 2025 22:02:23 -0500 Subject: [PATCH 56/96] add missing type hint Signed-off-by: Andrew Sansom --- tests/models/decoder_only/language/test_models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index d54c578e5acf..8805cc578f34 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -4,6 +4,7 @@ Run `pytest tests/models/test_models.py`. 
""" import os +from typing import Optional import pytest import torch @@ -114,7 +115,8 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) - prompt_embeds = [] if os.getenv("VLLM_USE_V1") == "0" else None + prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv( + "VLLM_USE_V1") == "0" else None prompt_token_ids = [] for prompt in example_prompts: token_ids = hf_model.tokenizer(prompt, From 01d83f4d13a696cc412d4beebe7f4d3c184db71c Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Sat, 19 Apr 2025 20:27:13 -0500 Subject: [PATCH 57/96] add spaces for style Signed-off-by: Andrew Sansom --- vllm/inputs/data.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 4286518637b2..8c3186d8f73b 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -64,8 +64,10 @@ class EmbedsPrompt(TypedDict): prompt_embeds: torch.Tensor """The embeddings of the prompt.""" + prompt_token_ids: NotRequired[list[int]] """A list of token IDs to pass to the model.""" + prompt: NotRequired[str] """The input text to be tokenized before passing to the model.""" From 69854521d90491d6342d509f749599204e177bdc Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Sat, 19 Apr 2025 20:34:09 -0500 Subject: [PATCH 58/96] seperate EmbedsInputs from TokenInputs and embeds_inputs from token_inputs to have separate structs for handling input embeddings Signed-off-by: Andrew Sansom --- tests/core/utils.py | 10 +++--- vllm/inputs/__init__.py | 8 +++-- vllm/inputs/data.py | 71 ++++++++++++++++++++++++++++++++------- vllm/inputs/preprocess.py | 20 +++++------ vllm/inputs/registry.py | 2 ++ 5 files changed, 80 insertions(+), 31 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index 1caa074a9ef2..4238a52e3aba 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -9,7 +9,7 @@ from vllm import SamplingParams from vllm.core.scheduler import Scheduler, SchedulerOutputs -from vllm.inputs import EncoderDecoderInputs, token_inputs +from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs from vllm.lora.request import LoRARequest from vllm.sequence import (Logprob, Sequence, SequenceGroup, SequenceGroupMetadata) @@ -34,11 +34,13 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) + inputs = token_inputs( + prompt_token_ids=prompt_tokens, + prompt=prompt_str) if prompt_embeds is None else embeds_inputs( + prompt_embeds=prompt_embeds, prompt=prompt_str) prompt = Sequence( int(request_id), - inputs=token_inputs(prompt_token_ids=prompt_tokens, - prompt=prompt_str, - prompt_embeds=prompt_embeds), + inputs=inputs, block_size=block_size, ) seq_group = SequenceGroup( diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 6f8f2cd758f7..6503b8b8251f 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, +from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, SingletonInputs, SingletonInputsAdapter, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, - build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, - token_inputs, zip_enc_dec_prompts) + build_explicit_enc_dec_prompt, embeds_inputs, + to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) 
from .registry import (DummyData, InputContext, InputProcessingContext, InputRegistry) @@ -22,7 +22,9 @@ "SingletonPrompt", "ExplicitEncoderDecoderPrompt", "TokenInputs", + "EmbedsInputs", "token_inputs", + "embeds_inputs", "DecoderOnlyInputs", "EncoderDecoderInputs", "ProcessorInputs", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 8c3186d8f73b..7a1e4ddcc6cf 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -152,9 +152,6 @@ class TokenInputs(TypedDict): prompt_token_ids: list[int] """The token IDs of the prompt.""" - prompt_embeds: NotRequired[torch.Tensor] - """The embeddings of the prompt, if available.""" - token_type_ids: NotRequired[list[int]] """The token type IDs of the prompt.""" @@ -198,7 +195,6 @@ def token_inputs( prompt_token_ids: list[int], token_type_ids: Optional[list[int]] = None, prompt: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_inputs: Optional["MultiModalKwargs"] = None, multi_modal_hashes: Optional[list[str]] = None, @@ -212,8 +208,6 @@ def token_inputs( inputs["prompt"] = prompt if token_type_ids is not None: inputs["token_type_ids"] = token_type_ids - if prompt_embeds is not None: - inputs["prompt_embeds"] = prompt_embeds if multi_modal_data is not None: inputs["multi_modal_data"] = multi_modal_data if multi_modal_inputs is not None: @@ -228,7 +222,42 @@ def token_inputs( return inputs -DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"] +class EmbedsInputs(TypedDict): + """Represents embeddings-based inputs.""" + + type: Literal["embeds"] + """The type of inputs.""" + + prompt_embeds: torch.Tensor + """The embeddings of the prompt, if available.""" + + prompt_token_ids: NotRequired[list[int]] + """The token IDs of the prompt.""" + + prompt: NotRequired[str] + """ + The original prompt text corresponding to the token IDs, if available. + """ + + +def embeds_inputs( + prompt_embeds: torch.Tensor, + prompt: Optional[str] = None, +) -> EmbedsInputs: + """Construct :class:`EmbedsInputs` from optional values.""" + inputs = EmbedsInputs( + type="embeds", + prompt_embeds=prompt_embeds, + prompt_token_ids=[0] * len(prompt_embeds), + ) + + if prompt is not None: + inputs["prompt"] = prompt + + return inputs + + +DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. @@ -250,7 +279,7 @@ class EncoderDecoderInputs(TypedDict): """The inputs for the decoder portion.""" -SingletonInputs = Union[TokenInputs, "MultiModalInputs"] +SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] """ A processed :class:`SingletonPrompt` which can be passed to :class:`vllm.sequence.Sequence`. 
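A small sketch of how the new embeds container is expected to behave, assuming the adapter is constructed directly from the TypedDict as it is elsewhere in the codebase (shapes and text are arbitrary):

    import torch

    from vllm.inputs import SingletonInputsAdapter, embeds_inputs

    inputs = embeds_inputs(prompt_embeds=torch.rand(4, 8), prompt="hello world")
    adapter = SingletonInputsAdapter(inputs)

    assert inputs["type"] == "embeds"
    assert inputs["prompt_token_ids"] == [0, 0, 0, 0]  # zero placeholder per row
    assert adapter.prompt_embeds is not None
    assert adapter.multi_modal_data == {}  # embeds inputs carry no multi-modal data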
@@ -268,7 +297,7 @@ class SingletonInputsAdapter: def prompt(self) -> Optional[str]: inputs = self.inputs - if inputs["type"] == "token" or inputs["type"] == "multimodal": + if inputs["type"] in ("token", "multimodal", "embeds"): return inputs.get("prompt") assert_never(inputs) # type: ignore[arg-type] @@ -277,7 +306,7 @@ def prompt(self) -> Optional[str]: def prompt_token_ids(self) -> list[int]: inputs = self.inputs - if inputs["type"] == "token" or inputs["type"] == "multimodal": + if inputs["type"] in ("token", "multimodal", "embeds"): return inputs.get("prompt_token_ids", []) assert_never(inputs) # type: ignore[arg-type] @@ -286,7 +315,7 @@ def prompt_token_ids(self) -> list[int]: def token_type_ids(self) -> list[int]: inputs = self.inputs - if inputs["type"] == "token" or inputs["type"] == "multimodal": + if inputs["type"] in ("token", "multimodal", "embeds"): return inputs.get("token_type_ids", []) assert_never(inputs) # type: ignore[arg-type] @@ -296,7 +325,10 @@ def prompt_embeds(self) -> Optional[torch.Tensor]: inputs = self.inputs if inputs["type"] == "token" or inputs["type"] == "multimodal": - return inputs.get("prompt_embeds") + return None + + if inputs["type"] == "embeds": + return inputs["prompt_embeds"] assert_never(inputs) # type: ignore[arg-type] @@ -310,6 +342,9 @@ def multi_modal_data(self) -> "MultiModalDataDict": if inputs["type"] == "multimodal": return inputs.get("mm_kwargs", {}) + if inputs["type"] == "embeds": + return {} + assert_never(inputs) # type: ignore[arg-type] @cached_property @@ -322,6 +357,9 @@ def multi_modal_inputs(self) -> Union[dict, "MultiModalKwargs"]: if inputs["type"] == "multimodal": return inputs.get("mm_kwargs", {}) + if inputs["type"] == "embeds": + return {} + assert_never(inputs) # type: ignore[arg-type] @cached_property @@ -335,6 +373,9 @@ def multi_modal_hashes(self) -> list[str]: # only the case when we use MultiModalInputs return inputs.get("mm_hashes", []) # type: ignore[return-value] + if inputs["type"] == "embeds": + return [] + assert_never(inputs) # type: ignore[arg-type] @cached_property @@ -347,6 +388,9 @@ def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": if inputs["type"] == "multimodal": return inputs.get("mm_placeholders", {}) + if inputs["type"] == "embeds": + return {} + assert_never(inputs) # type: ignore[arg-type] @cached_property @@ -359,6 +403,9 @@ def mm_processor_kwargs(self) -> dict[str, Any]: if inputs["type"] == "multimodal": return {} + if inputs["type"] == "embeds": + return {} + assert_never(inputs) # type: ignore[arg-type] diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 636dd9fd9866..5e212e841d0e 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -16,7 +16,8 @@ from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, - PromptType, SingletonInputs, SingletonPrompt, token_inputs) + PromptType, SingletonInputs, SingletonPrompt, embeds_inputs, + token_inputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt logger = init_logger(__name__) @@ -403,13 +404,9 @@ def _prompt_to_llm_inputs( raise ValueError( "prompt_embeds must be of shape (seq_len, hidden_size).") - prompt_token_ids = prompt_embeds_content.get( - "prompt_token_ids", [0] * len(prompt_embeds)) - - return token_inputs( - prompt_token_ids=prompt_token_ids, - prompt=prompt_embeds_content.get("prompt"), + return embeds_inputs( prompt_embeds=prompt_embeds, + 
prompt=prompt_embeds_content.get("prompt"), ) assert_never(parsed) @@ -501,12 +498,9 @@ async def _prompt_to_llm_inputs_async( raise ValueError( "prompt_embeds must be of shape (seq_len, hidden_size).") - prompt_token_ids = prompt_embeds_content.get( - "prompt_token_ids", [0] * len(prompt_embeds)) - return token_inputs( - prompt_token_ids=prompt_token_ids, - prompt=prompt_embeds_content.get("prompt"), + return embeds_inputs( prompt_embeds=prompt_embeds, + prompt=prompt_embeds_content.get("prompt"), ) assert_never(parsed) @@ -717,6 +711,8 @@ def _build_decoder_only_llm_inputs( prompt_inputs["prompt_token_ids"], prompt_adapter_request=prompt_adapter_request, ) + elif (prompt_inputs["type"] == "embeds"): + pass else: assert_never(prompt_inputs) # type: ignore[arg-type] diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 0579893e5d76..9df929391c71 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -438,6 +438,8 @@ def _ensure_mm_kwargs( elif inputs["type"] == "multimodal": # Be more strict in V2 assert "mm_kwargs" in inputs + elif inputs["type"] == "embeds": + pass else: assert_never(inputs["type"]) # type: ignore[arg-type] From e916551f3cb98e66eebfd683a3aefcadee13cf63 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Sat, 19 Apr 2025 20:37:33 -0500 Subject: [PATCH 59/96] fix docstrings for EmbedsInputs Signed-off-by: Andrew Sansom --- vllm/inputs/data.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 7a1e4ddcc6cf..5574fea7a193 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -229,10 +229,13 @@ class EmbedsInputs(TypedDict): """The type of inputs.""" prompt_embeds: torch.Tensor - """The embeddings of the prompt, if available.""" + """The embeddings of the prompt.""" prompt_token_ids: NotRequired[list[int]] - """The token IDs of the prompt.""" + """ + The token IDs of the prompt. Should always be a list of 0 of the same + length as prompt_embeds. 
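# ----------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patches above. It shows
# how the token-based and embedding-based input structs introduced in this
# series compare when built through the new helpers. The import path and the
# "type" discriminator come from the diffs in this series; the token IDs and
# the (seq_len, hidden_size) tensor shape are placeholder values.
import torch

from vllm.inputs import embeds_inputs, token_inputs

token_based = token_inputs(prompt_token_ids=[101, 2009, 2003, 102])
embed_based = embeds_inputs(prompt_embeds=torch.randn(4, 4096))

# "type" is the discriminator that SingletonInputsAdapter and the input
# preprocessor branch on; an embeds input carries no real token IDs, only
# placeholder zeros derived from the number of embedding rows.
assert token_based["type"] == "token"
assert embed_based["type"] == "embeds"
# ----------------------------------------------------------------------------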
+ """ prompt: NotRequired[str] """ From 69f87250d5487f0954386a76167653bff82dcd1e Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Sat, 19 Apr 2025 20:39:54 -0500 Subject: [PATCH 60/96] fix typing for token_type_ids Signed-off-by: Andrew Sansom --- vllm/inputs/data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 5574fea7a193..a6b07beda499 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -318,8 +318,10 @@ def prompt_token_ids(self) -> list[int]: def token_type_ids(self) -> list[int]: inputs = self.inputs - if inputs["type"] in ("token", "multimodal", "embeds"): + if inputs["type"] in ("token", "multimodal"): return inputs.get("token_type_ids", []) + if inputs["type"] == "embeds": + return [] assert_never(inputs) # type: ignore[arg-type] From 9c2c89fb6098cfe25e06bdd6d7bf5e01e38a1f0f Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Sat, 19 Apr 2025 21:05:24 -0500 Subject: [PATCH 61/96] fix typing for embeds_tokens in InputRegistry and InputsAdapter Signed-off-by: Andrew Sansom --- vllm/inputs/data.py | 17 ++++++++++------- vllm/inputs/registry.py | 3 ++- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index a6b07beda499..c45ad9dca567 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -318,10 +318,11 @@ def prompt_token_ids(self) -> list[int]: def token_type_ids(self) -> list[int]: inputs = self.inputs - if inputs["type"] in ("token", "multimodal"): - return inputs.get("token_type_ids", []) - if inputs["type"] == "embeds": - return [] + if inputs["type"] in ("token", "multimodal", "embeds"): + # mypy incorrectly infer the type of inputs.get("token_type_ids") + # claiming it to be `object` and not `list[int]`. 
+ return inputs.get("token_type_ids", + []) # type: ignore[return-value] assert_never(inputs) # type: ignore[arg-type] @@ -333,7 +334,7 @@ def prompt_embeds(self) -> Optional[torch.Tensor]: return None if inputs["type"] == "embeds": - return inputs["prompt_embeds"] + return inputs["prompt_embeds"] # type: ignore[typeddict-item] assert_never(inputs) # type: ignore[arg-type] @@ -372,7 +373,8 @@ def multi_modal_hashes(self) -> list[str]: inputs = self.inputs if inputs["type"] == "token": - return inputs.get("multi_modal_hashes", []) + return inputs.get("multi_modal_hashes", + []) # type: ignore[return-value] if inputs["type"] == "multimodal": # only the case when we use MultiModalInputs @@ -403,7 +405,8 @@ def mm_processor_kwargs(self) -> dict[str, Any]: inputs = self.inputs if inputs["type"] == "token": - return inputs.get("mm_processor_kwargs", {}) + return inputs.get("mm_processor_kwargs", + {}) # type: ignore[return-value] if inputs["type"] == "multimodal": return {} diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 9df929391c71..377e3d29c001 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -434,7 +434,8 @@ def _ensure_mm_kwargs( if inputs["type"] == "token": # In case the input processor for that model fails to set it if "mm_processor_kwargs" not in inputs: - inputs["mm_processor_kwargs"] = mm_processor_kwargs + inputs[ + "mm_processor_kwargs"] = mm_processor_kwargs # type: ignore[typeddict-unknown-key] elif inputs["type"] == "multimodal": # Be more strict in V2 assert "mm_kwargs" in inputs From 499dc6a7c67e3031af611031e7359b8d1e00992e Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 21 Apr 2025 08:23:03 -0500 Subject: [PATCH 62/96] remove prompts and prompt_token_ids from EmbedsPrompts Signed-off-by: Andrew Sansom --- vllm/engine/llm_engine.py | 4 +++- vllm/inputs/data.py | 26 ++++---------------------- 2 files changed, 7 insertions(+), 23 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 033fc376a162..ac4fbd1232d7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2062,10 +2062,12 @@ def _validate_model_input( tokenizer = (None if self.tokenizer is None else self.tokenizer.get_lora_tokenizer(lora_request)) - prompt_ids = prompt_inputs["prompt_token_ids"] + prompt_ids = prompt_inputs.get("prompt_token_ids", []) if not prompt_ids: if prompt_type == "encoder" and model_config.is_multimodal_model: pass # Mllama may have empty encoder inputs for text-only data + if prompt_inputs["type"] == "embeds": + pass else: raise ValueError(f"The {prompt_type} prompt cannot be empty") diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index c45ad9dca567..5f57e2e264cf 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -65,12 +65,6 @@ class EmbedsPrompt(TypedDict): prompt_embeds: torch.Tensor """The embeddings of the prompt.""" - prompt_token_ids: NotRequired[list[int]] - """A list of token IDs to pass to the model.""" - - prompt: NotRequired[str] - """The input text to be tokenized before passing to the model.""" - SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] """ @@ -231,17 +225,6 @@ class EmbedsInputs(TypedDict): prompt_embeds: torch.Tensor """The embeddings of the prompt.""" - prompt_token_ids: NotRequired[list[int]] - """ - The token IDs of the prompt. Should always be a list of 0 of the same - length as prompt_embeds. - """ - - prompt: NotRequired[str] - """ - The original prompt text corresponding to the token IDs, if available. 
- """ - def embeds_inputs( prompt_embeds: torch.Tensor, @@ -251,12 +234,8 @@ def embeds_inputs( inputs = EmbedsInputs( type="embeds", prompt_embeds=prompt_embeds, - prompt_token_ids=[0] * len(prompt_embeds), ) - if prompt is not None: - inputs["prompt"] = prompt - return inputs @@ -309,9 +288,12 @@ def prompt(self) -> Optional[str]: def prompt_token_ids(self) -> list[int]: inputs = self.inputs - if inputs["type"] in ("token", "multimodal", "embeds"): + if inputs["type"] in ("token", "multimodal"): return inputs.get("prompt_token_ids", []) + if inputs["type"] == "embeds": + return [0] * len(inputs["prompt_embeds"]) + assert_never(inputs) # type: ignore[arg-type] @cached_property From 6712ba67beda25cfb2abd8f6ff0657a61f44eb2c Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 28 Apr 2025 11:14:00 -0500 Subject: [PATCH 63/96] fight mypy to get correct typing for not embeds prompts Signed-off-by: Andrew Sansom --- vllm/inputs/preprocess.py | 73 ++++++++++++++++++--------------------- vllm/sequence.py | 18 ++++++---- 2 files changed, 45 insertions(+), 46 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 4dfc9789a7c4..d8b8039ec762 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -6,6 +6,7 @@ from typing_extensions import assert_never +from vllm import envs from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -15,10 +16,11 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import TokenizerGroup -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, - PromptType, SingletonInputs, SingletonPrompt, embeds_inputs, - token_inputs) -from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt +from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs, + ProcessorInputs, PromptType, SingletonInputs, + SingletonPrompt, TokenInputs, embeds_inputs, token_inputs) +from .parse import (ParsedEmbedsPrompt, is_explicit_encoder_decoder_prompt, + parse_singleton_prompt) logger = init_logger(__name__) @@ -361,22 +363,7 @@ def _prompt_to_llm_inputs( ) if parsed["type"] == "embeds": - prompt_embeds_content = parsed["content"] - - prompt_embeds = prompt_embeds_content["prompt_embeds"] - - # prompt_embeds must be (seq_len, hidden_size), but if the user - # passes in a batch of size 1, i.e. (1, seq_len, hidden_size), - # we can unambiguously process the intent by squeezing the batch - # dimension. - if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1: - prompt_embeds = prompt_embeds.squeeze(dim=0) - - if prompt_embeds.ndim != 2: - raise ValueError( - "prompt_embeds must be of shape (seq_len, hidden_size).") - - return embeds_inputs(prompt_embeds=prompt_embeds, ) + return self._process_prompt_embeds(parsed) assert_never(parsed) @@ -446,32 +433,36 @@ async def _prompt_to_llm_inputs_async( ) if parsed["type"] == "embeds": - prompt_embeds_content = parsed["content"] + return self._process_prompt_embeds(parsed) + + assert_never(parsed) - prompt_embeds = prompt_embeds_content["prompt_embeds"] + def _process_prompt_embeds(self, + parsed: ParsedEmbedsPrompt) -> EmbedsInputs: + if envs.VLLM_USE_V1: + raise ValueError("prompt_embeds is only available in V0.") - # prompt_embeds must be (seq_len, hidden_size), but if the user - # passes in a batch of size 1, i.e. (1, seq_len, hidden_size), - # we can unambiguously process the intent by squeezing the batch - # dimension. 
- if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1: - prompt_embeds = prompt_embeds.squeeze(dim=0) + prompt_embeds_content = parsed["content"] - if prompt_embeds.ndim != 2: - raise ValueError( - "prompt_embeds must be of shape (seq_len, hidden_size).") + prompt_embeds = prompt_embeds_content["prompt_embeds"] - return embeds_inputs( - prompt_embeds=prompt_embeds, - prompt=prompt_embeds_content.get("prompt"), - ) + # prompt_embeds must be (seq_len, hidden_size), but if the user + # passes in a batch of size 1, i.e. (1, seq_len, hidden_size), + # we can unambiguously process the intent by squeezing the batch + # dimension. + if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1: + prompt_embeds = prompt_embeds.squeeze(dim=0) - assert_never(parsed) + if prompt_embeds.ndim != 2: + raise ValueError( + "prompt_embeds must be of shape (seq_len, hidden_size).") + + return embeds_inputs(prompt_embeds=prompt_embeds) def _build_enc_dec_llm_inputs( self, - encoder_inputs: SingletonInputs, - decoder_inputs: Optional[SingletonInputs], + encoder_inputs: Union[TokenInputs, MultiModalInputs], + decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]], ) -> EncoderDecoderInputs: if (encoder_inputs["type"] == "token" or encoder_inputs["type"] == "multimodal"): @@ -479,6 +470,9 @@ def _build_enc_dec_llm_inputs( else: assert_never(encoder_inputs) # type: ignore[arg-type] + # Mypy does not correctly infer that EmbedsInputs is impossible + assert "prompt_token_ids" in encoder_inputs + if decoder_inputs is None: if self.model_config.hf_config.model_type == "whisper": # For Whisper models, the text prompt should go to the decoder. @@ -510,7 +504,8 @@ def _build_enc_dec_llm_inputs( def _separate_enc_dec_inputs_from_mm_processor_outputs( self, inputs: SingletonInputs, - decoder_inputs_to_override: Optional[SingletonInputs] = None, + decoder_inputs_to_override: Optional[Union[TokenInputs, + MultiModalInputs]] = None, ) -> tuple[SingletonInputs, SingletonInputs]: """ For encoder/decoder models only: diff --git a/vllm/sequence.py b/vllm/sequence.py index 82f4526873bd..8c4e7beae588 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -288,13 +288,12 @@ def output_token_ids(self, self._update_cached_all_tokens() @property - def output_token_embeds(self) -> tuple[int, ...]: - return tuple(self._output_token_ids) + def output_embeds(self) -> Optional[torch.Tensor]: + return self._output_embeds - @output_token_embeds.setter - def output_token_embeds(self, - new_output_token_embeds: torch.Tensor) -> None: - self._output_token_ids = new_output_token_embeds + @output_embeds.setter + def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None: + self._output_token_embeds = new_output_token_embeds self._update_cached_all_token_embeds() @property @@ -483,7 +482,8 @@ def __init__( self.data = SequenceData.from_seqs( self.prompt_token_ids, - prompt_embeds=self.inputs.get("prompt_embeds")) + prompt_embeds=self.inputs["prompt_embeds"] + if self.inputs["type"] == "embeds" else None) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -506,6 +506,8 @@ def n_blocks(self) -> int: @property def prompt(self) -> Optional[str]: + if self.inputs["type"] == "embeds": + return None return self.inputs.get("prompt") @property @@ -516,6 +518,8 @@ def prompt_token_ids(self) -> list[int]: @property def token_type_ids(self) -> list[int]: + if self.inputs["type"] == "embeds": + return [] return self.inputs.get("token_type_ids", []) @property From 740b290b9156b3d327c4efc70358b13ca8ecb09a Mon Sep 17 
00:00:00 2001 From: Andrew Sansom Date: Mon, 28 Apr 2025 17:30:52 -0500 Subject: [PATCH 64/96] remove incorrect call to embeds_inputs Signed-off-by: Andrew Sansom --- tests/core/utils.py | 2 +- vllm/inputs/data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index 4238a52e3aba..84b0426b470b 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -37,7 +37,7 @@ def create_dummy_prompt( inputs = token_inputs( prompt_token_ids=prompt_tokens, prompt=prompt_str) if prompt_embeds is None else embeds_inputs( - prompt_embeds=prompt_embeds, prompt=prompt_str) + prompt_embeds=prompt_embeds) prompt = Sequence( int(request_id), inputs=inputs, diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 09646db6cfd8..989146e09736 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -176,7 +176,7 @@ class EmbedsInputs(TypedDict): """The embeddings of the prompt.""" -def embeds_inputs(prompt_embeds: torch.Tensor, ) -> EmbedsInputs: +def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs: """Construct :class:`EmbedsInputs` from optional values.""" inputs = EmbedsInputs( type="embeds", From 8f9bd5135a901fd98deb7bc3f9b720540bc6c5e7 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 28 Apr 2025 20:20:37 -0500 Subject: [PATCH 65/96] wrestle with mypy and typeddict type narrowing Signed-off-by: Andrew Sansom --- vllm/inputs/data.py | 2 ++ vllm/inputs/parse.py | 38 ++++++++++++++++++++++++++++++++------ vllm/inputs/preprocess.py | 18 ++++++++++++++---- 3 files changed, 48 insertions(+), 10 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 989146e09736..649a7faf7746 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -67,6 +67,7 @@ class EmbedsPrompt(TypedDict): - A text prompt (:class:`str` or :class:`TextPrompt`) - A tokenized prompt (:class:`TokensPrompt`) +- An embeddings prompt (:class:`EmbedsPrompt`) Note that "singleton" is as opposed to a data structure which encapsulates multiple prompts, i.e. of the sort @@ -127,6 +128,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): - A text prompt (:class:`str` or :class:`TextPrompt`) - A tokenized prompt (:class:`TokensPrompt`) +- An embeddings prompt (:class:`EmbedsPrompt`) - A single data structure containing both an encoder and a decoder prompt (:class:`ExplicitEncoderDecoderPrompt`) """ diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index acec86361a51..397344e40230 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -6,9 +6,9 @@ from vllm.utils import is_list_of -from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs, - PromptType, SingletonInputs, SingletonPrompt, TextPrompt, - TokensPrompt) +from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt, + ProcessorInputs, PromptType, SingletonInputs, + SingletonPrompt, TextPrompt, TokensPrompt) class ParsedText(TypedDict): @@ -90,6 +90,26 @@ class ParsedEmbedsPrompt(TypedDict): content: EmbedsPrompt +@overload +def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt: + ... + + +@overload +def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt: + ... + + +@overload +def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt: + ... + + +@overload +def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt: + ... 
+ + def parse_singleton_prompt( prompt: SingletonPrompt, ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, @@ -112,9 +132,11 @@ def parse_singleton_prompt( def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: - return isinstance( - prompt, dict - ) and "prompt_token_ids" in prompt and "prompt_embeds" not in prompt + return isinstance(prompt, dict) and "prompt_token_ids" in prompt + + +def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]: + return isinstance(prompt, dict) and "prompt_embeds" in prompt def is_explicit_encoder_decoder_prompt( @@ -122,6 +144,10 @@ def is_explicit_encoder_decoder_prompt( return isinstance(prompt, dict) and "encoder_prompt" in prompt +def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]: + return isinstance(inputs, dict) and inputs["type"] == "embeds" + + def split_enc_dec_inputs( inputs: ProcessorInputs, ) -> tuple[Optional[SingletonInputs], SingletonInputs]: diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index d8b8039ec762..e91753b807c3 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -19,8 +19,8 @@ from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, TokenInputs, embeds_inputs, token_inputs) -from .parse import (ParsedEmbedsPrompt, is_explicit_encoder_decoder_prompt, - parse_singleton_prompt) +from .parse import (ParsedEmbedsPrompt, is_embeds_inputs, + is_explicit_encoder_decoder_prompt, parse_singleton_prompt) logger = init_logger(__name__) @@ -596,6 +596,8 @@ def _process_encoder_decoder_prompt( # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. if self.model_config.is_multimodal_model: + assert decoder_inputs is None or not is_embeds_inputs( + decoder_inputs) encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( encoder_inputs, decoder_inputs)) @@ -608,9 +610,12 @@ def _process_encoder_decoder_prompt( inputs)) else: encoder_inputs = inputs - decoder_inputs = None + # Mypy does not do type inference well with TypedDicts with Literal + # values. + assert not is_embeds_inputs(encoder_inputs) + assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs) return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) async def _process_encoder_decoder_prompt_async( @@ -637,6 +642,8 @@ async def _process_encoder_decoder_prompt_async( # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. if self.model_config.is_multimodal_model: + assert decoder_inputs is None or not is_embeds_inputs( + decoder_inputs) encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( encoder_inputs, decoder_inputs)) @@ -649,9 +656,12 @@ async def _process_encoder_decoder_prompt_async( inputs)) else: encoder_inputs = inputs - decoder_inputs = None + # Mypy does not do type inference well with TypedDicts with Literal + # values. 
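# ----------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patches above. The TypeIs
# guards added to vllm/inputs/parse.py in this commit let call sites branch on
# the prompt flavour while keeping mypy's type narrowing; the dict below is a
# placeholder prompt and the hidden size is a made-up value.
import torch

from vllm.inputs.parse import is_embeds_prompt, is_token_prompt

prompt = {"prompt_embeds": torch.randn(4, 4096)}

if is_embeds_prompt(prompt):
    # Narrowed to EmbedsPrompt: the prompt length is the number of rows.
    prompt_len = prompt["prompt_embeds"].shape[0]
elif is_token_prompt(prompt):
    # Narrowed to TokensPrompt: the prompt length is the number of token IDs.
    prompt_len = len(prompt["prompt_token_ids"])
# ----------------------------------------------------------------------------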
+ assert not is_embeds_inputs(encoder_inputs) + assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs) return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) def _build_decoder_only_llm_inputs( From b8d36c69d86bde809302df678bd903278f408289 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 28 Apr 2025 20:30:04 -0500 Subject: [PATCH 66/96] wrestle with mypy and typeddict type narrowing Signed-off-by: Andrew Sansom --- vllm/inputs/preprocess.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index e91753b807c3..e66877c88d09 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -671,6 +671,9 @@ def _build_decoder_only_llm_inputs( ) -> DecoderOnlyInputs: if (prompt_inputs["type"] == "token" or prompt_inputs["type"] == "multimodal"): + # Mypy does not do type inference well with typedicts and Literal + # values + assert not is_embeds_inputs(prompt_inputs) prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( prompt_inputs["prompt_token_ids"], prompt_adapter_request=prompt_adapter_request, From b764c1930e63fdfd30d2dbd13e52ced4eb4ddfcc Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 29 Apr 2025 09:32:39 -0500 Subject: [PATCH 67/96] support indexing graph runners that with inputs_embeds Signed-off-by: Andrew Sansom --- vllm/attention/backends/flashinfer.py | 14 +++++++++++--- vllm/spec_decode/draft_model_runner.py | 13 ++++++++++--- vllm/worker/enc_dec_model_runner.py | 15 +++++++++++---- vllm/worker/model_runner.py | 1 + vllm/worker/pooling_model_runner.py | 15 +++++++++++---- 5 files changed, 44 insertions(+), 14 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index d92177d58a48..37b20d0739f7 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -367,9 +367,17 @@ def begin_forward(self, model_input): # scheduled while CUDA graph mode is enabled. We don't run graph in that # case. 
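# ----------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patches above. The same
# lookup pattern recurs in the flashinfer backend above and in the draft,
# encoder-decoder and pooling runners below: CUDA graph runners are now keyed
# by (batch_size, uses_inputs_embeds) instead of batch size alone. The helper
# name here is hypothetical; `graph_runners` stands for one virtual engine's
# dict of captured runners.
from typing import Any, Optional

import torch


def select_graph_runner(
    graph_runners: dict[tuple[int, bool], Any],
    input_tokens: Optional[torch.Tensor],
    inputs_embeds: Optional[torch.Tensor],
) -> Any:
    # A decode batch driven by embeddings resolves to a different captured
    # graph than a token-ID batch of the same size.
    if inputs_embeds is None:
        assert input_tokens is not None
        key = (input_tokens.shape[0], False)
    else:
        key = (inputs_embeds.shape[0], True)
    return graph_runners[key]
# ----------------------------------------------------------------------------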
if use_cuda_graph and is_decode: - batch_size = model_input.input_tokens.shape[0] - state = (self.runner.graph_runners[model_input.virtual_engine] - [batch_size].attn_state) + if model_input.inputs_embeds is None: + batch_size = model_input.input_tokens.shape[0] + state = ( + self.runner.graph_runners[model_input.virtual_engine][( + batch_size, False)].attn_state) + else: + batch_size = model_input.inputs_embeds.shape[0] + state = ( + self.runner.graph_runners[model_input.virtual_engine][( + batch_size, True)].attn_state) + model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( ) model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 24095ef2a567..302bb322519a 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -242,9 +242,16 @@ def execute_model( # Get model if use_cuda_graph: - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = (self.graph_runners[model_input.virtual_engine] - [graph_batch_size]) + if model_input.inputs_embeds is None: + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, False)]) + else: + graph_batch_size = model_input.inputs_embeds.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, True)]) if previous_hidden_states is not None: hidden_states = torch.cat([ diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 4df192a8727c..c7ecc75b387c 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -172,10 +172,17 @@ def execute_model( if (model_input.attn_metadata is not None and model_input.attn_metadata.prefill_metadata is None and model_input.attn_metadata.decode_metadata.use_cuda_graph): - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[ - model_input.virtual_engine][graph_batch_size] + if model_input.inputs_embeds is None: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, False)]) + else: + graph_batch_size = model_input.inputs_embeds.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, True)]) else: model_executable = self.model diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ab7dd07035e1..9f6a7b382d48 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1081,6 +1081,7 @@ def __init__( self.max_batchsize_to_capture = \ self.vllm_config.compilation_config.max_capture_size + # self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [ {} for _ in range(self.parallel_config.pipeline_parallel_size) ] diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index cbd5e2060cad..fdb7353f2f9c 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -84,10 +84,17 @@ def execute_model( # explore how to leverage it. 
if (prefill_meta is None and decode_meta is not None and decode_meta.use_cuda_graph): - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[virtual_engine][ - graph_batch_size] + if model_input.inputs_embeds is None: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, False)]) + else: + graph_batch_size = model_input.inputs_embeds.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, True)]) else: model_executable = self.model From 0e75db48bd5e5e32db9bad1e786505a7d895f2e4 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Mon, 28 Oct 2024 11:00:39 -0500 Subject: [PATCH 68/96] feat: completions using embeddings Signed-off-by: Nan2018 --- tests/entrypoints/openai/test_completion.py | 149 +++++++++++++++++- vllm/entrypoints/logger.py | 9 +- vllm/entrypoints/openai/protocol.py | 21 ++- vllm/entrypoints/openai/serving_completion.py | 6 +- vllm/entrypoints/openai/serving_engine.py | 90 ++++++++--- 5 files changed, 241 insertions(+), 34 deletions(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 1d9aa4972b70..197be1decfb6 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # imports for guided decoding tests +import base64 +import io import json import re import shutil @@ -11,10 +13,11 @@ import openai # use the official client for correctness check import pytest import pytest_asyncio +import torch # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from vllm.transformers_utils.tokenizer import get_tokenizer @@ -31,6 +34,7 @@ PA_NUM_VIRTUAL_TOKENS = 8 GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] +CONFIG = AutoConfig.from_pretrained(MODEL_NAME) @pytest.fixture(scope="module") @@ -107,6 +111,14 @@ async def client(server): yield async_client +def create_dummy_embeds(num_tokens: int = 5) -> str: + """Create dummy embeddings and return them as base64 encoded string.""" + dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size) + buffer = io.BytesIO() + torch.save(dummy_embeds, buffer) + return base64.b64encode(buffer.getvalue()).decode('utf-8') + + @pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras, then test prompt adapters @@ -143,6 +155,45 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, assert len(completion.choices[0].text) >= 1 assert completion.choices[0].prompt_logprobs is None + # test using prompt_embeds + encoded_embeds = create_dummy_embeds() + completion = await client.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + assert len(completion.choices[0].text) >= 1 + assert completion.choices[0].prompt_logprobs is None + + # test batch completion with prompt_embeds + encoded_embeds2 = create_dummy_embeds() + completion = await client.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + 
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + assert len(completion.choices) == 2 + assert len(completion.choices[0].text) >= 1 + assert len(completion.choices[1].text) >= 1 + + # test error case: neither prompt nor prompt_embeds provided + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + max_tokens=5, + temperature=0.0, + ) + + # test error case: invalid prompt_embeds + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": "invalid_base64"}) + @pytest.mark.asyncio async def test_added_lora_tokens(client: openai.AsyncOpenAI): @@ -343,6 +394,55 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, assert chunk.choices[0].text assert "".join(chunks) == single_output + # test streaming with prompt_embeds + encoded_embeds = create_dummy_embeds() + single_completion = await client.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + single_output = single_completion.choices[0].text + + stream = await client.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + stream=True, + extra_body={"prompt_embeds": encoded_embeds}) + chunks = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + # test batch streaming with prompt_embeds + encoded_embeds2 = create_dummy_embeds() + stream = await client.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + stream=True, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + chunks = [[], []] + finish_reason_count = 0 + async for chunk in stream: + chunks[chunk.choices[0].index].append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == 2 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert len(chunks[0]) > 0 + assert len(chunks[1]) > 0 + @pytest.mark.asyncio @pytest.mark.parametrize( @@ -760,6 +860,53 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI, 1) <= len(top_logprobs) <= logprobs_arg + 1 assert len(logprobs.tokens) > 5 + # test using prompt_embeds + encoded_embeds = create_dummy_embeds() + completion = await client.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=logprobs_arg, + extra_body={"prompt_embeds": encoded_embeds}) + + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) > 5 + + # test batch completion with prompt_embeds + encoded_embeds2 = create_dummy_embeds() + completion = await 
client.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=logprobs_arg, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + + assert len(completion.choices) == 2 + for choice in completion.choices: + logprobs = choice.logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) > 5 + @pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index ea5759152a22..d4655dd5e6ab 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -2,6 +2,8 @@ from typing import Optional, Union +import torch + from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams @@ -23,6 +25,7 @@ def log_inputs( request_id: str, prompt: Optional[str], prompt_token_ids: Optional[list[int]], + prompt_embeds: Optional[torch.Tensor], params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], lora_request: Optional[LoRARequest], @@ -39,6 +42,8 @@ def log_inputs( logger.info( "Received request %s: prompt: %r, " "params: %s, prompt_token_ids: %s, " + "prompt_embeds shape: %s, " "lora_request: %s, prompt_adapter_request: %s.", request_id, - prompt, params, prompt_token_ids, lora_request, - prompt_adapter_request) + prompt, params, prompt_token_ids, + prompt_embeds.shape if prompt_embeds is not None else None, + lora_request, prompt_adapter_request) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 015943762ab1..f13a98c69705 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -730,8 +730,9 @@ def check_generation_prompt(cls, data): class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create - model: Optional[str] = None - prompt: Union[list[int], list[list[int]], str, list[str]] + model: str + prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None + prompt_embeds: Optional[Union[bytes, list[bytes]]] = None best_of: Optional[int] = None echo: Optional[bool] = False frequency_penalty: Optional[float] = 0.0 @@ -1005,6 +1006,14 @@ def validate_stream_options(cls, data): return data + @model_validator(mode="before") + @classmethod + def validate_prompt_and_prompt_embeds(cls, data): + if data.get("prompt") is None and data.get("prompt_embeds") is None: + raise ValueError( + "At least one of `prompt` or `prompt_embeds` must be set.") + return data + class EmbeddingCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation @@ -1622,9 +1631,9 @@ class TranscriptionRequest(OpenAIBaseModel): # doc: begin-transcription-extra-params stream: Optional[bool] = False - """Custom field not present in the original OpenAI definition. When set, + """Custom field not present in the original OpenAI definition. When set, it will enable output to be streamed in a similar fashion as the Chat - Completion endpoint. + Completion endpoint. 
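# ----------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patches above. It mirrors
# create_dummy_embeds() from the test earlier in this commit to show the
# client-side round trip for the new `prompt_embeds` field: torch.save() the
# tensor, base64-encode it, and send it via extra_body; the serving layer
# base64-decodes it and loads it back with torch.load, expecting a 2-D float
# tensor of shape (seq_len, hidden_size). The base URL, model name and hidden
# size below are placeholders.
import base64
import io

import openai
import torch


def encode_prompt_embeds(embeds: torch.Tensor) -> str:
    buffer = io.BytesIO()
    torch.save(embeds, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="my-model",  # placeholder model name
    prompt="",  # empty prompt; generation is driven by the embeddings
    max_tokens=5,
    temperature=0.0,
    extra_body={"prompt_embeds": encode_prompt_embeds(torch.randn(5, 4096))},
)
print(completion.choices[0].text)
# ----------------------------------------------------------------------------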
""" # Flattened stream option to simplify form data. stream_include_usage: Optional[bool] = False @@ -1642,7 +1651,7 @@ class TranscriptionRequest(OpenAIBaseModel): """ top_p: Optional[float] = None - """Enables nucleus (top-p) sampling, where tokens are selected from the + """Enables nucleus (top-p) sampling, where tokens are selected from the smallest possible set whose cumulative probability exceeds `p`. """ @@ -1650,7 +1659,7 @@ class TranscriptionRequest(OpenAIBaseModel): """Limits sampling to the `k` most probable tokens at each step.""" min_p: Optional[float] = None - """Filters out tokens with a probability lower than `min_p`, ensuring a + """Filters out tokens with a probability lower than `min_p`, ensuring a minimum likelihood threshold during sampling. """ diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 1067f35ce240..72086a11bf7e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -131,7 +131,9 @@ async def create_completion( for i, engine_prompt in enumerate(engine_prompts): sampling_params: Union[SamplingParams, BeamSearchParams] default_max_tokens = self.max_model_len - len( - engine_prompt["prompt_token_ids"]) + engine_prompt.get("prompt_token_ids", []) + or engine_prompt.get("prompt_embeds", [])) + if request.use_beam_search: sampling_params = request.to_beam_search_params( default_max_tokens, self.default_sampling_params) @@ -211,7 +213,7 @@ async def create_completion( # We did not pass it into vLLM engine to avoid being redundant # with the inputs token IDs if final_res.prompt is None: - final_res.prompt = request_prompts[i]["prompt"] + final_res.prompt = request_prompts[i].get("prompt") final_res_batch_checked = cast(list[RequestOutput], final_res_batch) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 49b346a23baf..6a2b6665682f 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 - +import base64 +import io import json from collections.abc import Iterable, Iterator, Mapping, Sequence from concurrent.futures.thread import ThreadPoolExecutor from http import HTTPStatus from typing import Annotated, Any, Callable, Optional, TypedDict, Union +import torch from fastapi import Request from pydantic import Field from starlette.datastructures import Headers @@ -67,7 +69,11 @@ class TextTokensPrompt(TypedDict): prompt_token_ids: list[int] -RequestPrompt = Union[list[int], str, TextTokensPrompt] +class EmbedsPrompt(TypedDict): + prompt_embeds: torch.Tensor + + +RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt] class OpenAIServing: @@ -320,10 +326,11 @@ def _tokenize_prompt_input_or_inputs( self, request: AnyRequest, tokenizer: AnyTokenizer, - input_or_inputs: Union[str, list[str], list[int], list[list[int]]], + input_or_inputs: Optional[Union[str, list[str], list[int], + list[list[int]]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, - ) -> list[TextTokensPrompt]: + ) -> Union[list[TextTokensPrompt], list[EmbedsPrompt]]: """ Tokenize/detokenize depending on the input format. 
@@ -335,21 +342,27 @@ def _tokenize_prompt_input_or_inputs( # VSCode Pyright extension should still work properly # "is True" is required for Pyright to perform type narrowing # See: https://github.com/microsoft/pyright/issues/7672 - return [ - self._normalize_prompt_text_to_input( - request, - tokenizer, - prompt=prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens) - if prompt_input["is_tokens"] is False else - self._normalize_prompt_tokens_to_input( - request, - tokenizer, - prompt_ids=prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens) - for prompt_input in parse_and_batch_prompt(input_or_inputs) - ] + request_prompts = [] + if input_or_inputs: + request_prompts.extend([ + self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens) + if prompt_input["is_tokens"] is False else + self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_ids=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens) + for prompt_input in parse_and_batch_prompt(input_or_inputs) + ]) + request_prompts.extend( + self._load_prompt_embeds(request.prompt_embeds, + truncate_prompt_tokens)) + return request_prompts async def _preprocess_completion( self, @@ -358,7 +371,8 @@ async def _preprocess_completion( input_or_inputs: Union[str, list[str], list[int], list[list[int]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, - ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]: + ) -> tuple[Union[list[TextTokensPrompt], list[EmbedsPrompt]], Union[ + list[TokensPrompt], list[EmbedsPrompt]]]: request_prompts = await self._tokenize_prompt_input_or_inputs_async( request, tokenizer, @@ -368,6 +382,7 @@ async def _preprocess_completion( ) engine_prompts = [ + request_prompt if "prompt_embeds" in request_prompt else TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"]) for request_prompt in request_prompts ] @@ -472,6 +487,34 @@ async def _preprocess_chat( return conversation, [request_prompt], [engine_prompt] + def _load_prompt_embeds( + self, + prompt_embeds: Optional[Union[bytes, list[bytes]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + ) -> list[EmbedsPrompt]: + + def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: + tensor = torch.load(io.BytesIO(base64.b64decode(embed))) + assert isinstance( + tensor, + (torch.FloatTensor, torch.BFloat16Tensor, torch.HalfTensor)) + if tensor.dim() > 2: + tensor = tensor.squeeze(0) + assert tensor.dim() == 2 + if truncate_prompt_tokens is not None: + tensor = tensor[-truncate_prompt_tokens:] + return {"prompt_embeds": tensor} + + if prompt_embeds: + if isinstance(prompt_embeds, list): + return [ + _load_and_validate_embed(embed) for embed in prompt_embeds + ] + else: + return [_load_and_validate_embed(prompt_embeds)] + else: + return [] + def _log_inputs( self, request_id: str, @@ -483,13 +526,13 @@ def _log_inputs( ) -> None: if self.request_logger is None: return - + prompt, prompt_token_ids, prompt_embeds = None, None, None if isinstance(inputs, str): prompt = inputs - prompt_token_ids = None elif isinstance(inputs, list): - prompt = None prompt_token_ids = inputs + elif 'prompt_embeds' in inputs: + prompt_embeds = inputs.get("prompt_embeds") else: prompt = inputs["prompt"] prompt_token_ids = inputs["prompt_token_ids"] @@ -498,6 +541,7 @@ 
def _log_inputs( request_id, prompt, prompt_token_ids, + prompt_embeds, params=params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, From 85642d0c7816751b348ba4a57dee41c696224439 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Thu, 1 May 2025 16:09:19 -0500 Subject: [PATCH 69/96] support encoder decoder models with inputs_embeds Signed-off-by: Andrew Sansom --- vllm/worker/enc_dec_model_runner.py | 2 ++ vllm/worker/model_runner.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index c7ecc75b387c..4864163b0de2 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -49,6 +49,7 @@ class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, + "inputs_embeds": self.inputs_embeds, "input_positions": self.input_positions, "encoder_input_tokens": self.encoder_input_tokens, "encoder_input_positions": self.encoder_input_positions, @@ -196,6 +197,7 @@ def execute_model( model_input.virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, + inputs_embeds=model_input.inputs_embeds, positions=model_input.input_positions, encoder_input_ids=model_input.encoder_input_tokens, encoder_positions=model_input.encoder_input_positions, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9f6a7b382d48..8a16fd6a9abd 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1646,7 +1646,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][( - batch_size, use_inputs_embeds)] = (graph_runner) + batch_size, use_inputs_embeds)] = graph_runner if self.lora_config: self._remove_dummy_loras() From b226fd6be57e4c011c34906c6b975b57eeadb97e Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Thu, 1 May 2025 16:47:33 -0500 Subject: [PATCH 70/96] simplify redundant ternary statement Signed-off-by: Andrew Sansom --- vllm/worker/model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8a16fd6a9abd..85814e9af9e3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1831,8 +1831,7 @@ def execute_model( self.vllm_config, virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, - inputs_embeds=model_input.inputs_embeds - if model_input.inputs_embeds is not None else None, + inputs_embeds=model_input.inputs_embeds, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, From b738d3fdb93b81a1f86f8fe1c618d8e82a06481e Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Thu, 1 May 2025 16:48:06 -0500 Subject: [PATCH 71/96] explicitly remove support for inputs embeds with speculative decoding models Signed-off-by: Andrew Sansom --- vllm/spec_decode/draft_model_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 302bb322519a..a6276c563394 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -201,6 +201,9 @@ def execute_model( if self.prompt_adapter_config 
is not None: raise ValueError("TP1DraftModelRunner has no support for " "prompt_adapter_config") + if model_input.inputs_embeds is not None: + raise ValueError("TP1DraftModelRunner has no support for " + "inputs_embeds") if model_input.multi_modal_kwargs: raise ValueError( "TP1DraftModelRunner has no support for multi_modal_kwargs" @@ -288,6 +291,7 @@ def execute_model( self.vllm_config): hidden_states = model_executable( input_ids=model_input.input_tokens, + inputs_embeds=None, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, From 2340119d4d8faa0c4de43fe58c89e80767b94a9f Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Thu, 1 May 2025 17:04:34 -0500 Subject: [PATCH 72/96] fix occasional device mismatch errors when appending output tokens to a sample outputs Signed-off-by: Andrew Sansom --- vllm/sequence.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 8c4e7beae588..5bc9b8a6fc82 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -342,7 +342,9 @@ def append_token_id(self, (self._output_embeds, token_embed), dim=0) assert self._cached_all_token_embeds is not None self._cached_all_token_embeds = torch.cat( - (self._cached_all_token_embeds, token_embed), dim=0) + (self._cached_all_token_embeds, + token_embed.to(device=self._cached_all_token_embeds.device)), + dim=0) def get_len(self) -> int: return len(self._output_token_ids) + len(self._prompt_token_ids) From ab5ea30b7b94e0361e432a0bc333ccb6680b8d6c Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Fri, 9 May 2025 11:03:39 -0500 Subject: [PATCH 73/96] fix typing Signed-off-by: Nan2018 --- vllm/entrypoints/openai/serving_completion.py | 11 ++++++----- vllm/entrypoints/openai/serving_engine.py | 2 +- vllm/entrypoints/openai/serving_tokenization.py | 5 +++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 72086a11bf7e..485940a5b673 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -130,9 +130,10 @@ async def create_completion( try: for i, engine_prompt in enumerate(engine_prompts): sampling_params: Union[SamplingParams, BeamSearchParams] - default_max_tokens = self.max_model_len - len( - engine_prompt.get("prompt_token_ids", []) - or engine_prompt.get("prompt_embeds", [])) + prompt_token_ids = engine_prompt.get("prompt_token_ids", []) + prompt_embeds = engine_prompt.get("prompt_embeds", []) + default_max_tokens = self.max_model_len - len(prompt_token_ids + or prompt_embeds) if request.use_beam_search: sampling_params = request.to_beam_search_params( @@ -278,8 +279,8 @@ async def completion_stream_generator( prompt_text = res.prompt # Prompt details are excluded from later streamed outputs - if res.prompt_token_ids is not None: - num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids) + if prompt_token_ids is not None: + num_prompt_tokens[prompt_idx] = len(prompt_token_ids) delta_token_ids: GenericSequence[int] out_logprobs: Optional[GenericSequence[Optional[dict[ diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 24d22bc39260..cd82dc25de78 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -381,7 +381,7 @@ async def _preprocess_completion( add_special_tokens=add_special_tokens, ) - engine_prompts = [ + 
engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]] = [ request_prompt if "prompt_embeds" in request_prompt else TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"]) for request_prompt in request_prompts diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index c642fc51005e..a43cfea82e6d 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -103,8 +103,9 @@ async def create_tokenize( # Silently ignore prompt adapter since it does not affect # tokenization (Unlike in Embeddings API where an error is raised) - - input_ids.extend(engine_prompt["prompt_token_ids"]) + if isinstance(engine_prompt, + dict) and "prompt_token_ids" in engine_prompt: + input_ids.extend(engine_prompt["prompt_token_ids"]) return TokenizeResponse(tokens=input_ids, count=len(input_ids), From 2c2dc0aaae39cde9bcf08963bdd4ec229ed5aaf7 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Fri, 9 May 2025 11:35:21 -0500 Subject: [PATCH 74/96] torch load weights only; raise error if prompt embeds and lora or prompt adapter are used together Signed-off-by: Nan2018 --- vllm/entrypoints/openai/cli_args.py | 8 ++++++++ vllm/entrypoints/openai/serving_engine.py | 3 ++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index a2639d374791..d0cbb3c7f1ca 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -283,6 +283,14 @@ def validate_parsed_serve_args(args: argparse.Namespace): if args.enable_auto_tool_choice and not args.tool_call_parser: raise TypeError("Error: --enable-auto-tool-choice requires " "--tool-call-parser") + if args.enable_prompt_embeds: + if args.enable_lora: + raise ValueError( + "Cannot use prompt embeds and lora at the same time.") + if args.enable_prompt_adapter: + raise ValueError( + "Cannot use prompt embeds and prompt adapter at the same time." 
+ ) def create_parser_for_docs() -> FlexibleArgumentParser: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index cd82dc25de78..221942586f9b 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -497,7 +497,8 @@ def _load_prompt_embeds( ) -> list[EmbedsPrompt]: def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: - tensor = torch.load(io.BytesIO(base64.b64decode(embed))) + tensor = torch.load(io.BytesIO(base64.b64decode(embed)), + weights_only=True) assert isinstance( tensor, (torch.FloatTensor, torch.BFloat16Tensor, torch.HalfTensor)) From 6147e3cdbccabc119fb28ab9fac7cebd8bab6fba Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 9 May 2025 15:09:37 -0500 Subject: [PATCH 75/96] refactor to resolve type errors in serving_completion Signed-off-by: Andrew Sansom --- vllm/entrypoints/openai/serving_completion.py | 33 +++++++++-- vllm/entrypoints/openai/serving_engine.py | 55 +++++++++++++++---- vllm/inputs/data.py | 13 ++++- 3 files changed, 82 insertions(+), 19 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 485940a5b673..891ca3128819 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,6 +8,7 @@ import jinja2 from fastapi import Request +from typing_extensions import assert_never from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient @@ -25,8 +26,11 @@ UsageInfo) # yapf: enable from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - clamp_prompt_logprobs) + clamp_prompt_logprobs, + is_text_tokens_prompt) from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, + is_tokens_prompt) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams @@ -130,10 +134,23 @@ async def create_completion( try: for i, engine_prompt in enumerate(engine_prompts): sampling_params: Union[SamplingParams, BeamSearchParams] - prompt_token_ids = engine_prompt.get("prompt_token_ids", []) - prompt_embeds = engine_prompt.get("prompt_embeds", []) - default_max_tokens = self.max_model_len - len(prompt_token_ids - or prompt_embeds) + # Mypy does not infer that engine_prompt will have only one of + # "prompt_token_ids" or "prompt_embeds" defined, and both of + # these as Union[object, the expected type], where it infers + # object if engine_prompt is a subclass of one of the + # typeddicts that defines both keys. Worse, because of + # https://github.com/python/mypy/issues/8586, mypy does not + # infer the type of engine_prompt correctly because of the + # enumerate. So we need an unnecessary cast here. 
+ engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], + engine_prompt) + if is_embeds_prompt(engine_prompt): + input_length = len(engine_prompt["prompt_embeds"]) + elif is_tokens_prompt(engine_prompt): + input_length = len(engine_prompt["prompt_token_ids"]) + else: + assert_never(engine_prompt) + default_max_tokens = self.max_model_len - input_length if request.use_beam_search: sampling_params = request.to_beam_search_params( @@ -214,7 +231,11 @@ async def create_completion( # We did not pass it into vLLM engine to avoid being redundant # with the inputs token IDs if final_res.prompt is None: - final_res.prompt = request_prompts[i].get("prompt") + request_prompt = request_prompts[i] + if is_text_tokens_prompt(request_prompt): + final_res.prompt = request_prompt["prompt"] + else: + final_res.prompt = None final_res_batch_checked = cast(list[RequestOutput], final_res_batch) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 221942586f9b..ba3c4fe59b93 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -5,12 +5,13 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence from concurrent.futures.thread import ThreadPoolExecutor from http import HTTPStatus -from typing import Annotated, Any, Callable, Optional, TypedDict, Union +from typing import Annotated, Any, Callable, Optional, TypedDict, Union, cast import torch from fastapi import Request from pydantic import Field from starlette.datastructures import Headers +from typing_extensions import TypeIs, assert_never import vllm.envs as envs from vllm.config import ModelConfig @@ -38,7 +39,8 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser # yapf: enable -from vllm.inputs import TokensPrompt +from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -76,6 +78,16 @@ class EmbedsPrompt(TypedDict): RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt] +def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt) + + +def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt + and "prompt_embeds" in prompt) + + class OpenAIServing: def __init__( @@ -368,11 +380,12 @@ async def _preprocess_completion( self, request: CompletionLikeRequest, tokenizer: AnyTokenizer, - input_or_inputs: Union[str, list[str], list[int], list[list[int]]], + input_or_inputs: Optional[Union[str, list[str], list[int], + list[list[int]]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> tuple[Union[list[TextTokensPrompt], list[EmbedsPrompt]], Union[ - list[TokensPrompt], list[EmbedsPrompt]]]: + list[EngineTokensPrompt], list[EngineEmbedsPrompt]]]: request_prompts = await self._tokenize_prompt_input_or_inputs_async( request, tokenizer, @@ -381,13 +394,31 @@ async def _preprocess_completion( add_special_tokens=add_special_tokens, ) - engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]] = [ - request_prompt if "prompt_embeds" in request_prompt else - 
TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"]) - for request_prompt in request_prompts - ] + # This list has a weird type hint. This should be a union of lists, not + # a list of unions, but mypy does not play nicely with the variance of + # typeddicts, and there's no way to easily assert which list type we + # received from request_prompts. So the compromise is we type the list + # as a union as we build it, and then cast it to the correct type when + # we return. + # This casting is safe because we know the list can only be full of + # either EngineTokensPrompt or EngineEmbedsPrompt, but not a mix. + engine_prompts: list[Union[EngineTokensPrompt, + EngineEmbedsPrompt]] = [] + for request_prompt in request_prompts: + if is_embeds_prompt(request_prompt): + engine_prompts.append( + EngineEmbedsPrompt( + prompt_embeds=request_prompt["prompt_embeds"])) + elif is_text_tokens_prompt(request_prompt): + engine_prompts.append( + EngineTokensPrompt( + prompt_token_ids=request_prompt["prompt_token_ids"])) + else: + assert_never(request_prompt) - return request_prompts, engine_prompts + return request_prompts, cast( + Union[list[EngineTokensPrompt], list[EngineEmbedsPrompt]], + engine_prompts) async def _preprocess_chat( self, @@ -405,7 +436,7 @@ async def _preprocess_chat( truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = False, ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], - list[TokensPrompt]]: + list[EngineTokensPrompt]]: model_config = self.model_config resolved_content_format = resolve_chat_template_content_format( @@ -478,7 +509,7 @@ async def _preprocess_chat( prompt=tokenizer.decode(request_prompt), prompt_token_ids=request_prompt) - engine_prompt = TokensPrompt( + engine_prompt = EngineTokensPrompt( prompt_token_ids=prompt_inputs["prompt_token_ids"]) if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index c83ab73b614a..3b58ec47d5bf 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast import torch -from typing_extensions import NotRequired, TypedDict, TypeVar +from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar if TYPE_CHECKING: from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs @@ -98,6 +98,17 @@ class EmbedsPrompt(TypedDict): more than one prompt, i.e. 
{class}`ExplicitEncoderDecoderPrompt` """ + +def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt) + + +def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt + and "prompt_embeds" in prompt) + + _T1_co = TypeVar("_T1_co", bound=SingletonPrompt, default=SingletonPrompt, From 61d26419e4d32a0616111b3821d90315eb20c81a Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 9 May 2025 16:40:24 -0500 Subject: [PATCH 76/96] refactor to resolve type errors in serving_engine.py Signed-off-by: Andrew Sansom --- vllm/entrypoints/openai/serving_engine.py | 57 +++++++++++++++++------ 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index ba3c4fe59b93..8f29a6a8c34b 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -5,7 +5,8 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence from concurrent.futures.thread import ThreadPoolExecutor from http import HTTPStatus -from typing import Annotated, Any, Callable, Optional, TypedDict, Union, cast +from typing import (Annotated, Any, Callable, Optional, TypedDict, Union, cast, + overload) import torch from fastapi import Request @@ -350,13 +351,13 @@ def _tokenize_prompt_input_or_inputs( , each input can be a string or array of tokens. Note that each request can pass one or more inputs. """ - # Although our type checking is based on mypy, - # VSCode Pyright extension should still work properly - # "is True" is required for Pyright to perform type narrowing - # See: https://github.com/microsoft/pyright/issues/7672 - request_prompts = [] - if input_or_inputs: - request_prompts.extend([ + + if input_or_inputs is not None: + # Although our type checking is based on mypy, + # VSCode Pyright extension should still work properly + # "is False" is required for Pyright to perform type narrowing + # See: https://github.com/microsoft/pyright/issues/7672 + return [ self._normalize_prompt_text_to_input( request, tokenizer, @@ -370,11 +371,35 @@ def _tokenize_prompt_input_or_inputs( prompt_ids=prompt_input["content"], truncate_prompt_tokens=truncate_prompt_tokens) for prompt_input in parse_and_batch_prompt(input_or_inputs) - ]) - request_prompts.extend( - self._load_prompt_embeds(request.prompt_embeds, - truncate_prompt_tokens)) - return request_prompts + ] + if not isinstance(request, CompletionRequest): + raise ValueError( + "Using prompt embeddings with any request other than a" + " CompletionRequest is not supported.") + return self._load_prompt_embeds(request.prompt_embeds, + truncate_prompt_tokens) + + @overload + async def _preprocess_completion( + self, + request: CompletionLikeRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Union[str, list[str], list[int], list[list[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., + add_special_tokens: bool = ..., + ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]: + ... + + @overload + async def _preprocess_completion( + self, + request: CompletionRequest, + tokenizer: AnyTokenizer, + input_or_inputs: None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., + add_special_tokens: bool = ..., + ) -> tuple[list[EmbedsPrompt], list[EngineEmbedsPrompt]]: + ... 
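A minimal, self-contained sketch of the overload pattern these stubs rely on; the names and element types here are placeholders rather than vLLM's, but the shape is the same: the overload the type checker selects, and therefore the prompt list type the caller sees, is keyed on the type of the raw input argument.

from typing import Optional, Union, overload

@overload
def _preprocess(raw: str) -> list[int]: ...
@overload
def _preprocess(raw: None) -> list[float]: ...

def _preprocess(raw: Optional[str]) -> Union[list[int], list[float]]:
    if raw is None:
        # Stand-in for the embeds-only path (no text or token input provided).
        return [0.0]
    # Stand-in for the tokenization path.
    return [len(token) for token in raw.split()]

token_like: list[int] = _preprocess("a b c")  # first overload is selected
embed_like: list[float] = _preprocess(None)   # second overload is selected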
async def _preprocess_completion( self, @@ -386,6 +411,12 @@ async def _preprocess_completion( add_special_tokens: bool = True, ) -> tuple[Union[list[TextTokensPrompt], list[EmbedsPrompt]], Union[ list[EngineTokensPrompt], list[EngineEmbedsPrompt]]]: + if not isinstance(request, + CompletionRequest) and input_or_inputs is None: + raise ValueError( + "Prompt embeds with non-completion requests is not" + " currently supported.") + request_prompts = await self._tokenize_prompt_input_or_inputs_async( request, tokenizer, From 4af2b64c7d772b214be0a1aaa23e79bb29009516 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 9 May 2025 16:53:58 -0500 Subject: [PATCH 77/96] serving completions typing Signed-off-by: Andrew Sansom --- vllm/entrypoints/openai/serving_completion.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 891ca3128819..43065dda10b1 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -172,6 +172,11 @@ async def create_completion( trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers)) + # Mypy inconsistently requires this second cast in different + # environments. It shouldn't be necessary (redundant from above) + # but pre-commit in CI fails without it. + engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], + engine_prompt) if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( prompt=engine_prompt, From 27ed406a377bb73b04323bdaa65052667e35e745 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 12 May 2025 13:34:47 -0500 Subject: [PATCH 78/96] prefer prompt embeds for completion requests when available Signed-off-by: Andrew Sansom --- tests/entrypoints/openai/test_completion.py | 9 ++-- vllm/entrypoints/openai/serving_completion.py | 2 +- vllm/entrypoints/openai/serving_engine.py | 53 +++++++++---------- 3 files changed, 31 insertions(+), 33 deletions(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 197be1decfb6..1918086ff500 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -431,17 +431,18 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, temperature=0.0, stream=True, extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) - chunks = [[], []] + chunks_stream_embeds: list[list[str]] = [[], []] finish_reason_count = 0 async for chunk in stream: - chunks[chunk.choices[0].index].append(chunk.choices[0].text) + chunks_stream_embeds[chunk.choices[0].index].append( + chunk.choices[0].text) if chunk.choices[0].finish_reason is not None: finish_reason_count += 1 assert finish_reason_count == 2 assert chunk.choices[0].finish_reason == "length" assert chunk.choices[0].text - assert len(chunks[0]) > 0 - assert len(chunks[1]) > 0 + assert len(chunks_stream_embeds[0]) > 0 + assert len(chunks_stream_embeds[1]) > 0 @pytest.mark.asyncio diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 43065dda10b1..da25ac257915 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -112,7 +112,7 @@ async def create_completion( request_prompts, engine_prompts = await self._preprocess_completion( request, tokenizer, - request.prompt, + request.prompt or None, 
truncate_prompt_tokens=request.truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8f29a6a8c34b..52654fcad55f 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -351,33 +351,31 @@ def _tokenize_prompt_input_or_inputs( , each input can be a string or array of tokens. Note that each request can pass one or more inputs. """ - - if input_or_inputs is not None: - # Although our type checking is based on mypy, - # VSCode Pyright extension should still work properly - # "is False" is required for Pyright to perform type narrowing - # See: https://github.com/microsoft/pyright/issues/7672 - return [ - self._normalize_prompt_text_to_input( - request, - tokenizer, - prompt=prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens) - if prompt_input["is_tokens"] is False else - self._normalize_prompt_tokens_to_input( - request, - tokenizer, - prompt_ids=prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens) - for prompt_input in parse_and_batch_prompt(input_or_inputs) - ] - if not isinstance(request, CompletionRequest): - raise ValueError( - "Using prompt embeddings with any request other than a" - " CompletionRequest is not supported.") - return self._load_prompt_embeds(request.prompt_embeds, - truncate_prompt_tokens) + # We want to always ignore text prompts if prompt embeds are available + if (isinstance(request, CompletionRequest) + and request.prompt_embeds is not None): + return self._load_prompt_embeds(request.prompt_embeds, + truncate_prompt_tokens) + + # Although our type checking is based on mypy, + # VSCode Pyright extension should still work properly + # "is False" is required for Pyright to perform type narrowing + # See: https://github.com/microsoft/pyright/issues/7672 + return [ + self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens) + if prompt_input["is_tokens"] is False else + self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_ids=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens) + for prompt_input in parse_and_batch_prompt(input_or_inputs) + ] @overload async def _preprocess_completion( @@ -416,7 +414,6 @@ async def _preprocess_completion( raise ValueError( "Prompt embeds with non-completion requests is not" " currently supported.") - request_prompts = await self._tokenize_prompt_input_or_inputs_async( request, tokenizer, From 72e124433bb5877817976c5d258a12916d577b44 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 12 May 2025 15:35:59 -0500 Subject: [PATCH 79/96] explicitly do not support echo and prompt embeds Signed-off-by: Andrew Sansom --- vllm/entrypoints/openai/serving_completion.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index da25ac257915..5b0a4f1b6048 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -94,6 +94,10 @@ async def create_completion( return self.create_error_response( "suffix is not currently supported") + if request.echo and request.prompt_embeds is not None: + return self.create_error_response( + "Echo is unsupported with prompt embeds.") + request_id = 
f"cmpl-{self._base_request_id(raw_request)}" created_time = int(time.time()) From db00178e8b1a6e6f3b6f0623f1892dbd9c01272a Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 12 May 2025 15:36:36 -0500 Subject: [PATCH 80/96] refactor tests for completions endpoints with prompt embeds to require v0 Signed-off-by: Andrew Sansom --- tests/entrypoints/openai/test_completion.py | 299 +++++++++++--------- 1 file changed, 170 insertions(+), 129 deletions(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 1918086ff500..0b062c2e82a0 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -6,6 +6,7 @@ import json import re import shutil +from collections.abc import Sequence from tempfile import TemporaryDirectory from typing import Optional @@ -64,8 +65,7 @@ def zephyr_pa_files(): @pytest.fixture(scope="module") -def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, - zephyr_pa_files): +def common_server_args() -> list[str]: return [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -75,6 +75,14 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, "--max-num-seqs", "128", "--enforce-eager", + ] + + +@pytest.fixture(scope="module") +def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, + zephyr_pa_files, common_server_args: Sequence[str]): + return [ + *common_server_args, # lora config "--enable-lora", "--lora-modules", @@ -105,12 +113,39 @@ def server(default_server_args, request): yield remote_server +@pytest.fixture(scope='module') +def monkeymodule(): + with pytest.MonkeyPatch.context() as mp: + yield mp + + +@pytest.fixture(scope="module", + params=["", "--disable-frontend-multiprocessing"]) +def server_with_prompt_embeds(common_server_args, request, monkeymodule): + if request.param: + common_server_args.append(request.param) + + # Prompt embeds are currently only supported on v0 + monkeymodule.setenv("VLLM_USE_V1", "0") + # We use the common server args instead of the default server args because + # prompt embeds are not compatible with Lora or Prompt Adapter requests + common_server_args.append("--enable-prompt-embeds") + with RemoteOpenAIServer(MODEL_NAME, common_server_args) as remote_server: + yield remote_server + + @pytest_asyncio.fixture async def client(server): async with server.get_async_client() as async_client: yield async_client +@pytest_asyncio.fixture +async def client_with_prompt_embeds(server_with_prompt_embeds): + async with server_with_prompt_embeds.get_async_client() as async_client: + yield async_client + + def create_dummy_embeds(num_tokens: int = 5) -> str: """Create dummy embeddings and return them as base64 encoded string.""" dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size) @@ -155,45 +190,6 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, assert len(completion.choices[0].text) >= 1 assert completion.choices[0].prompt_logprobs is None - # test using prompt_embeds - encoded_embeds = create_dummy_embeds() - completion = await client.completions.create( - model=model_name, - prompt="", # Add empty prompt as required parameter - max_tokens=5, - temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) - assert len(completion.choices[0].text) >= 1 - assert completion.choices[0].prompt_logprobs is None - - # test batch completion with prompt_embeds - encoded_embeds2 = create_dummy_embeds() - completion = await 
client.completions.create( - model=model_name, - prompt="", # Add empty prompt as required parameter - max_tokens=5, - temperature=0.0, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) - assert len(completion.choices) == 2 - assert len(completion.choices[0].text) >= 1 - assert len(completion.choices[1].text) >= 1 - - # test error case: neither prompt nor prompt_embeds provided - with pytest.raises(BadRequestError): - await client.completions.create( - model=model_name, - max_tokens=5, - temperature=0.0, - ) - - # test error case: invalid prompt_embeds - with pytest.raises(BadRequestError): - await client.completions.create( - model=model_name, - max_tokens=5, - temperature=0.0, - extra_body={"prompt_embeds": "invalid_base64"}) - @pytest.mark.asyncio async def test_added_lora_tokens(client: openai.AsyncOpenAI): @@ -394,56 +390,6 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, assert chunk.choices[0].text assert "".join(chunks) == single_output - # test streaming with prompt_embeds - encoded_embeds = create_dummy_embeds() - single_completion = await client.completions.create( - model=model_name, - prompt="", # Add empty prompt as required parameter - max_tokens=5, - temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) - single_output = single_completion.choices[0].text - - stream = await client.completions.create( - model=model_name, - prompt="", # Add empty prompt as required parameter - max_tokens=5, - temperature=0.0, - stream=True, - extra_body={"prompt_embeds": encoded_embeds}) - chunks = [] - finish_reason_count = 0 - async for chunk in stream: - chunks.append(chunk.choices[0].text) - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - assert finish_reason_count == 1 - assert chunk.choices[0].finish_reason == "length" - assert chunk.choices[0].text - assert "".join(chunks) == single_output - - # test batch streaming with prompt_embeds - encoded_embeds2 = create_dummy_embeds() - stream = await client.completions.create( - model=model_name, - prompt="", # Add empty prompt as required parameter - max_tokens=5, - temperature=0.0, - stream=True, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) - chunks_stream_embeds: list[list[str]] = [[], []] - finish_reason_count = 0 - async for chunk in stream: - chunks_stream_embeds[chunk.choices[0].index].append( - chunk.choices[0].text) - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - assert finish_reason_count == 2 - assert chunk.choices[0].finish_reason == "length" - assert chunk.choices[0].text - assert len(chunks_stream_embeds[0]) > 0 - assert len(chunks_stream_embeds[1]) > 0 - @pytest.mark.asyncio @pytest.mark.parametrize( @@ -861,36 +807,153 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI, 1) <= len(top_logprobs) <= logprobs_arg + 1 assert len(logprobs.tokens) > 5 - # test using prompt_embeds + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_json_schema, sample_regex): + with pytest.raises(openai.BadRequestError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example JSON that fits this schema: 42", + extra_body=dict(guided_json=42, + guided_decoding_backend=guided_decoding_backend)) + + with pytest.raises(openai.BadRequestError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example 
string that fits this regex", + extra_body=dict(guided_regex=sample_regex, + guided_json=sample_json_schema)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + # Test case: Single prompt embeds input encoded_embeds = create_dummy_embeds() - completion = await client.completions.create( + completion = await client_with_prompt_embeds.completions.create( model=model_name, prompt="", # Add empty prompt as required parameter max_tokens=5, temperature=0.0, - echo=True, + extra_body={"prompt_embeds": encoded_embeds}) + assert len(completion.choices[0].text) >= 1 + assert completion.choices[0].prompt_logprobs is None + + # Test case: batch completion with prompt_embeds + encoded_embeds2 = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + assert len(completion.choices) == 2 + assert len(completion.choices[0].text) >= 1 + assert len(completion.choices[1].text) >= 1 + + # Test case: streaming with prompt_embeds + encoded_embeds = create_dummy_embeds() + single_completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + single_output = single_completion.choices[0].text + + stream = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + stream=True, + extra_body={"prompt_embeds": encoded_embeds}) + chunks = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + # Test case: batch streaming with prompt_embeds + encoded_embeds2 = create_dummy_embeds() + stream = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + stream=True, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + chunks_stream_embeds: list[list[str]] = [[], []] + finish_reason_count = 0 + async for chunk in stream: + chunks_stream_embeds[chunk.choices[0].index].append( + chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == 2 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert len(chunks_stream_embeds[0]) > 0 + assert len(chunks_stream_embeds[1]) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_errors_with_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + # Test error case: invalid prompt_embeds + with pytest.raises(BadRequestError): + await client_with_prompt_embeds.completions.create( + prompt="", + model=model_name, + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": "invalid_base64"}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("logprobs_arg", [1, 0]) 
+@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_logprobs_and_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int, + model_name: str): + # Test case: Logprobs using prompt_embeds + encoded_embeds = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + echo=False, logprobs=logprobs_arg, extra_body={"prompt_embeds": encoded_embeds}) logprobs = completion.choices[0].logprobs assert logprobs is not None - assert len(logprobs.text_offset) > 5 - assert (len(logprobs.token_logprobs) > 5 - and logprobs.token_logprobs[0] is None) - assert (len(logprobs.top_logprobs) > 5 - and logprobs.top_logprobs[0] is None) + assert len(logprobs.text_offset) == 5 + assert len(logprobs.token_logprobs) == 5 + assert len(logprobs.top_logprobs) == 5 for top_logprobs in logprobs.top_logprobs[1:]: assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 - assert len(logprobs.tokens) > 5 + assert len(logprobs.tokens) == 5 - # test batch completion with prompt_embeds + # Test case: Log probs with batch completion and prompt_embeds encoded_embeds2 = create_dummy_embeds() - completion = await client.completions.create( + completion = await client_with_prompt_embeds.completions.create( model=model_name, prompt="", # Add empty prompt as required parameter max_tokens=5, temperature=0.0, - echo=True, + echo=False, logprobs=logprobs_arg, extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) @@ -898,32 +961,10 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI, for choice in completion.choices: logprobs = choice.logprobs assert logprobs is not None - assert len(logprobs.text_offset) > 5 - assert (len(logprobs.token_logprobs) > 5 - and logprobs.token_logprobs[0] is None) - assert (len(logprobs.top_logprobs) > 5 - and logprobs.top_logprobs[0] is None) + assert len(logprobs.text_offset) == 5 + assert len(logprobs.token_logprobs) == 5 + assert len(logprobs.top_logprobs) == 5 for top_logprobs in logprobs.top_logprobs[1:]: assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 - assert len(logprobs.tokens) > 5 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_json_schema, sample_regex): - with pytest.raises(openai.BadRequestError): - _ = await client.completions.create( - model=MODEL_NAME, - prompt="Give an example JSON that fits this schema: 42", - extra_body=dict(guided_json=42, - guided_decoding_backend=guided_decoding_backend)) - - with pytest.raises(openai.BadRequestError): - _ = await client.completions.create( - model=MODEL_NAME, - prompt="Give an example string that fits this regex", - extra_body=dict(guided_regex=sample_regex, - guided_json=sample_json_schema)) + assert len(logprobs.tokens) == 5 From 318ee3f5bc86062acd466fed46db4c0d5fc98eac Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 12 May 2025 15:59:39 -0500 Subject: [PATCH 81/96] style Signed-off-by: Andrew Sansom --- vllm/entrypoints/openai/serving_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 866bfe4c578c..6b473ec6ee28 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ 
b/vllm/entrypoints/openai/serving_engine.py @@ -136,9 +136,9 @@ class RequestProcessingMixin(BaseModel): """ request_prompts: Optional[Sequence[RequestPrompt]] = \ Field(default_factory=list) - engine_prompts: Optional[\ - list[EngineTokensPrompt] | list[EngineEmbedsPrompt]] = \ - Field(default_factory=list) + engine_prompts: Optional[Union[list[EngineTokensPrompt], + list[EngineEmbedsPrompt]]] = Field( + default_factory=list) model_config = ConfigDict(arbitrary_types_allowed=True) From 78754b0bc4885c15a1f9591208e1ff28870c6fd0 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 12 May 2025 16:03:16 -0500 Subject: [PATCH 82/96] add None check Signed-off-by: Andrew Sansom --- vllm/entrypoints/openai/serving_engine.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 6b473ec6ee28..4ef628096161 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -639,6 +639,11 @@ def _tokenize_prompt_input_or_inputs( return self._load_prompt_embeds(request.prompt_embeds, truncate_prompt_tokens) + if input_or_inputs is None: + raise ValueError( + "Prompt can only be None for a `/v1/completions` request" + " with prompt_embeds") + # Although our type checking is based on mypy, # VSCode Pyright extension should still work properly # "is False" is required for Pyright to perform type narrowing From 719168d5208d66554c2cd9c2ed429983a63b5eeb Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 12 May 2025 16:09:59 -0500 Subject: [PATCH 83/96] appease mypy Signed-off-by: Andrew Sansom --- vllm/entrypoints/openai/serving_engine.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 4ef628096161..8249b2971346 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -336,6 +336,12 @@ async def _prepare_generators( lora_request=ctx.lora_request, prompt_adapter_request=ctx.prompt_adapter_request) + # Mypy has an existing bug related to inferring the variance of + # TypedDicts with `builtins.enumerate`: + # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 + engine_prompt = cast( + Union[EngineTokensPrompt, EngineEmbedsPrompt], + engine_prompt) generator = self.engine_client.encode( engine_prompt, pooling_params, From 1ea957ecfffb2e764dac59328a13a85022a1498e Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 12 May 2025 21:34:00 -0500 Subject: [PATCH 84/96] pass in empty string prompts to preprocess to allow downstream handling Signed-off-by: Andrew Sansom --- vllm/entrypoints/openai/serving_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 565bf0151f63..7beaae287de9 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -116,7 +116,7 @@ async def create_completion( request_prompts, engine_prompts = await self._preprocess_completion( request, tokenizer, - request.prompt or None, + request.prompt, truncate_prompt_tokens=request.truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) From 03db71a1a0f6f849b3bcf3a62ceb9af2b5e1d8db Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 12 May 2025 21:34:45 -0500 Subject: [PATCH 85/96] re-add ability to allow model to be None in completion requests 
(accidentally reverted in bad merge commit) Signed-off-by: Andrew Sansom --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3dfaf25bdbfe..47c4581e5371 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -744,7 +744,7 @@ def check_cache_salt_support(cls, data): class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create - model: str + model: Optional[str] = None prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None prompt_embeds: Optional[Union[bytes, list[bytes]]] = None best_of: Optional[int] = None From c7122c41f344c9f98abe83d4efd221032feac1f8 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Mon, 12 May 2025 21:38:15 -0500 Subject: [PATCH 86/96] update type hint Signed-off-by: Andrew Sansom --- tests/entrypoints/openai/test_completion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 0b062c2e82a0..4c19bfcebf34 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -6,7 +6,6 @@ import json import re import shutil -from collections.abc import Sequence from tempfile import TemporaryDirectory from typing import Optional @@ -80,7 +79,7 @@ def common_server_args() -> list[str]: @pytest.fixture(scope="module") def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, - zephyr_pa_files, common_server_args: Sequence[str]): + zephyr_pa_files, common_server_args: list[str]): return [ *common_server_args, # lora config From c0e064775a28821ec9d9d820c583bc561e7846ff Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 13 May 2025 11:12:21 -0500 Subject: [PATCH 87/96] pass in env_dict instead of failing to mock properly Signed-off-by: Andrew Sansom --- tests/entrypoints/openai/test_completion.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 4c19bfcebf34..57e8b0411b3f 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -112,24 +112,20 @@ def server(default_server_args, request): yield remote_server -@pytest.fixture(scope='module') -def monkeymodule(): - with pytest.MonkeyPatch.context() as mp: - yield mp - - @pytest.fixture(scope="module", params=["", "--disable-frontend-multiprocessing"]) -def server_with_prompt_embeds(common_server_args, request, monkeymodule): +def server_with_prompt_embeds(common_server_args, request): if request.param: common_server_args.append(request.param) - # Prompt embeds are currently only supported on v0 - monkeymodule.setenv("VLLM_USE_V1", "0") # We use the common server args instead of the default server args because # prompt embeds are not compatible with Lora or Prompt Adapter requests common_server_args.append("--enable-prompt-embeds") - with RemoteOpenAIServer(MODEL_NAME, common_server_args) as remote_server: + + # Prompt embeds are currently only supported on v0 + with RemoteOpenAIServer(MODEL_NAME, + common_server_args, + env_dict={"VLLM_USE_V1": "0"}) as remote_server: yield remote_server From 56f10df8d7797ad7cef56cf67249f02c7aed9a28 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 13 May 
2025 20:38:28 -0500 Subject: [PATCH 88/96] enable lora with prompt embeds Signed-off-by: Andrew Sansom --- vllm/entrypoints/openai/cli_args.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 80dda6299fc7..d01af5e42266 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -286,14 +286,9 @@ def validate_parsed_serve_args(args: argparse.Namespace): if args.enable_auto_tool_choice and not args.tool_call_parser: raise TypeError("Error: --enable-auto-tool-choice requires " "--tool-call-parser") - if args.enable_prompt_embeds: - if args.enable_lora: - raise ValueError( - "Cannot use prompt embeds and lora at the same time.") - if args.enable_prompt_adapter: - raise ValueError( - "Cannot use prompt embeds and prompt adapter at the same time." - ) + if args.enable_prompt_embeds and args.enable_prompt_adapter: + raise ValueError( + "Cannot use prompt embeds and prompt adapter at the same time.") def log_non_default_args(args: argparse.Namespace): From 92b336a318f972e1bf0c215924a6aec9c2bf904a Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 13 May 2025 20:46:16 -0500 Subject: [PATCH 89/96] disable chunked prefill in openai + prompt embeds checks Signed-off-by: Andrew Sansom --- tests/entrypoints/openai/test_completion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 57e8b0411b3f..da28e33fc9bb 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -119,8 +119,9 @@ def server_with_prompt_embeds(common_server_args, request): common_server_args.append(request.param) # We use the common server args instead of the default server args because - # prompt embeds are not compatible with Lora or Prompt Adapter requests + # prompt embeds are not compatible with Prompt Adapter requests common_server_args.append("--enable-prompt-embeds") + common_server_args.append("--no-enable-chunked-prefill") # Prompt embeds are currently only supported on v0 with RemoteOpenAIServer(MODEL_NAME, From 72674e07fce603ded329ed69097acac652f54c3e Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 13 May 2025 21:33:17 -0500 Subject: [PATCH 90/96] move prompt embeds completions endpoint tests to their own file to avoid having two vLLM instances in memory at once Signed-off-by: Andrew Sansom --- tests/entrypoints/openai/test_completion.py | 191 +------------- .../test_completion_with_prompt_embeds.py | 236 ++++++++++++++++++ 2 files changed, 239 insertions(+), 188 deletions(-) create mode 100644 tests/entrypoints/openai/test_completion_with_prompt_embeds.py diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index da28e33fc9bb..1d9aa4972b70 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # imports for guided decoding tests -import base64 -import io import json import re import shutil @@ -13,11 +11,10 @@ import openai # use the official client for correctness check import pytest import pytest_asyncio -import torch # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoTokenizer from 
vllm.transformers_utils.tokenizer import get_tokenizer @@ -34,7 +31,6 @@ PA_NUM_VIRTUAL_TOKENS = 8 GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] -CONFIG = AutoConfig.from_pretrained(MODEL_NAME) @pytest.fixture(scope="module") @@ -64,7 +60,8 @@ def zephyr_pa_files(): @pytest.fixture(scope="module") -def common_server_args() -> list[str]: +def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, + zephyr_pa_files): return [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -74,14 +71,6 @@ def common_server_args() -> list[str]: "--max-num-seqs", "128", "--enforce-eager", - ] - - -@pytest.fixture(scope="module") -def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, - zephyr_pa_files, common_server_args: list[str]): - return [ - *common_server_args, # lora config "--enable-lora", "--lora-modules", @@ -112,44 +101,12 @@ def server(default_server_args, request): yield remote_server -@pytest.fixture(scope="module", - params=["", "--disable-frontend-multiprocessing"]) -def server_with_prompt_embeds(common_server_args, request): - if request.param: - common_server_args.append(request.param) - - # We use the common server args instead of the default server args because - # prompt embeds are not compatible with Prompt Adapter requests - common_server_args.append("--enable-prompt-embeds") - common_server_args.append("--no-enable-chunked-prefill") - - # Prompt embeds are currently only supported on v0 - with RemoteOpenAIServer(MODEL_NAME, - common_server_args, - env_dict={"VLLM_USE_V1": "0"}) as remote_server: - yield remote_server - - @pytest_asyncio.fixture async def client(server): async with server.get_async_client() as async_client: yield async_client -@pytest_asyncio.fixture -async def client_with_prompt_embeds(server_with_prompt_embeds): - async with server_with_prompt_embeds.get_async_client() as async_client: - yield async_client - - -def create_dummy_embeds(num_tokens: int = 5) -> str: - """Create dummy embeddings and return them as base64 encoded string.""" - dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size) - buffer = io.BytesIO() - torch.save(dummy_embeds, buffer) - return base64.b64encode(buffer.getvalue()).decode('utf-8') - - @pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras, then test prompt adapters @@ -822,145 +779,3 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, prompt="Give an example string that fits this regex", extra_body=dict(guided_regex=sample_regex, guided_json=sample_json_schema)) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_completions_with_prompt_embeds( - client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): - # Test case: Single prompt embeds input - encoded_embeds = create_dummy_embeds() - completion = await client_with_prompt_embeds.completions.create( - model=model_name, - prompt="", # Add empty prompt as required parameter - max_tokens=5, - temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) - assert len(completion.choices[0].text) >= 1 - assert completion.choices[0].prompt_logprobs is None - - # Test case: batch completion with prompt_embeds - encoded_embeds2 = create_dummy_embeds() - completion = await client_with_prompt_embeds.completions.create( - model=model_name, - prompt="", # Add empty prompt as required parameter - max_tokens=5, - temperature=0.0, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) - 
assert len(completion.choices) == 2 - assert len(completion.choices[0].text) >= 1 - assert len(completion.choices[1].text) >= 1 - - # Test case: streaming with prompt_embeds - encoded_embeds = create_dummy_embeds() - single_completion = await client_with_prompt_embeds.completions.create( - model=model_name, - prompt="", # Add empty prompt as required parameter - max_tokens=5, - temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) - single_output = single_completion.choices[0].text - - stream = await client_with_prompt_embeds.completions.create( - model=model_name, - prompt="", # Add empty prompt as required parameter - max_tokens=5, - temperature=0.0, - stream=True, - extra_body={"prompt_embeds": encoded_embeds}) - chunks = [] - finish_reason_count = 0 - async for chunk in stream: - chunks.append(chunk.choices[0].text) - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - assert finish_reason_count == 1 - assert chunk.choices[0].finish_reason == "length" - assert chunk.choices[0].text - assert "".join(chunks) == single_output - - # Test case: batch streaming with prompt_embeds - encoded_embeds2 = create_dummy_embeds() - stream = await client_with_prompt_embeds.completions.create( - model=model_name, - prompt="", # Add empty prompt as required parameter - max_tokens=5, - temperature=0.0, - stream=True, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) - chunks_stream_embeds: list[list[str]] = [[], []] - finish_reason_count = 0 - async for chunk in stream: - chunks_stream_embeds[chunk.choices[0].index].append( - chunk.choices[0].text) - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - assert finish_reason_count == 2 - assert chunk.choices[0].finish_reason == "length" - assert chunk.choices[0].text - assert len(chunks_stream_embeds[0]) > 0 - assert len(chunks_stream_embeds[1]) > 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_completions_errors_with_prompt_embeds( - client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): - # Test error case: invalid prompt_embeds - with pytest.raises(BadRequestError): - await client_with_prompt_embeds.completions.create( - prompt="", - model=model_name, - max_tokens=5, - temperature=0.0, - extra_body={"prompt_embeds": "invalid_base64"}) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("logprobs_arg", [1, 0]) -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_completions_with_logprobs_and_prompt_embeds( - client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int, - model_name: str): - # Test case: Logprobs using prompt_embeds - encoded_embeds = create_dummy_embeds() - completion = await client_with_prompt_embeds.completions.create( - model=model_name, - prompt="", # Add empty prompt as required parameter - max_tokens=5, - temperature=0.0, - echo=False, - logprobs=logprobs_arg, - extra_body={"prompt_embeds": encoded_embeds}) - - logprobs = completion.choices[0].logprobs - assert logprobs is not None - assert len(logprobs.text_offset) == 5 - assert len(logprobs.token_logprobs) == 5 - assert len(logprobs.top_logprobs) == 5 - for top_logprobs in logprobs.top_logprobs[1:]: - assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 - assert len(logprobs.tokens) == 5 - - # Test case: Log probs with batch completion and prompt_embeds - encoded_embeds2 = create_dummy_embeds() - completion = await client_with_prompt_embeds.completions.create( - model=model_name, - prompt="", # Add empty prompt 
as required parameter - max_tokens=5, - temperature=0.0, - echo=False, - logprobs=logprobs_arg, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) - - assert len(completion.choices) == 2 - for choice in completion.choices: - logprobs = choice.logprobs - assert logprobs is not None - assert len(logprobs.text_offset) == 5 - assert len(logprobs.token_logprobs) == 5 - assert len(logprobs.top_logprobs) == 5 - for top_logprobs in logprobs.top_logprobs[1:]: - assert max(logprobs_arg, - 1) <= len(top_logprobs) <= logprobs_arg + 1 - assert len(logprobs.tokens) == 5 diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py new file mode 100644 index 000000000000..b1015cd925a3 --- /dev/null +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 + +# imports for guided decoding tests +import base64 +import io +import shutil +from tempfile import TemporaryDirectory + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio +import torch +# downloading lora to test lora requests +from huggingface_hub import snapshot_download +from openai import BadRequestError +from transformers import AutoConfig, AutoTokenizer + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +# technically these adapters use a different base model, +# but we're not testing generation quality here +LORA_NAME = "typeof/zephyr-7b-beta-lora" + +CONFIG = AutoConfig.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def zephyr_lora_added_tokens_files(zephyr_lora_files): + tmp_dir = TemporaryDirectory() + tmp_model_dir = f"{tmp_dir.name}/zephyr" + shutil.copytree(zephyr_lora_files, tmp_model_dir) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + # Copy tokenizer to adapter and add some unique tokens + # 32000, 32001, 32002 + added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], + special_tokens=True) + assert added == 3 + tokenizer.save_pretrained(tmp_model_dir) + yield tmp_model_dir + tmp_dir.cleanup() + + +@pytest.fixture(scope="module") +def default_server_args( + zephyr_lora_files, + zephyr_lora_added_tokens_files, +) -> list[str]: + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--enforce-eager", + # Prompt Embeds server args + "--enable-prompt-embeds", + "--no-enable-chunked-prefill", + ] + + +@pytest.fixture(scope="module", + params=["", "--disable-frontend-multiprocessing"]) +def server_with_prompt_embeds(default_server_args, request): + if request.param: + default_server_args.append(request.param) + + # Prompt embeds are currently only supported on v0 + with RemoteOpenAIServer(MODEL_NAME, + default_server_args, + env_dict={"VLLM_USE_V1": "0"}) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client_with_prompt_embeds(server_with_prompt_embeds): + async with server_with_prompt_embeds.get_async_client() as async_client: + yield async_client + + +def create_dummy_embeds(num_tokens: int = 5) -> str: + """Create dummy embeddings and return them as base64 encoded string.""" + dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size) + buffer = 
io.BytesIO() + torch.save(dummy_embeds, buffer) + return base64.b64encode(buffer.getvalue()).decode('utf-8') + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + # Test case: Single prompt embeds input + encoded_embeds = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + assert len(completion.choices[0].text) >= 1 + assert completion.choices[0].prompt_logprobs is None + + # Test case: batch completion with prompt_embeds + encoded_embeds2 = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + assert len(completion.choices) == 2 + assert len(completion.choices[0].text) >= 1 + assert len(completion.choices[1].text) >= 1 + + # Test case: streaming with prompt_embeds + encoded_embeds = create_dummy_embeds() + single_completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + single_output = single_completion.choices[0].text + + stream = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + stream=True, + extra_body={"prompt_embeds": encoded_embeds}) + chunks = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + # Test case: batch streaming with prompt_embeds + encoded_embeds2 = create_dummy_embeds() + stream = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + stream=True, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + chunks_stream_embeds: list[list[str]] = [[], []] + finish_reason_count = 0 + async for chunk in stream: + chunks_stream_embeds[chunk.choices[0].index].append( + chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == 2 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert len(chunks_stream_embeds[0]) > 0 + assert len(chunks_stream_embeds[1]) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_errors_with_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + # Test error case: invalid prompt_embeds + with pytest.raises(BadRequestError): + await client_with_prompt_embeds.completions.create( + prompt="", + model=model_name, + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": "invalid_base64"}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("logprobs_arg", [1, 0]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def 
test_completions_with_logprobs_and_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int, + model_name: str): + # Test case: Logprobs using prompt_embeds + encoded_embeds = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + echo=False, + logprobs=logprobs_arg, + extra_body={"prompt_embeds": encoded_embeds}) + + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) == 5 + assert len(logprobs.token_logprobs) == 5 + assert len(logprobs.top_logprobs) == 5 + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) == 5 + + # Test case: Log probs with batch completion and prompt_embeds + encoded_embeds2 = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + echo=False, + logprobs=logprobs_arg, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + + assert len(completion.choices) == 2 + for choice in completion.choices: + logprobs = choice.logprobs + assert logprobs is not None + assert len(logprobs.text_offset) == 5 + assert len(logprobs.token_logprobs) == 5 + assert len(logprobs.top_logprobs) == 5 + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) == 5 From 7134fe11418aea217b97cb0fe1593940a0b6979f Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 13 May 2025 22:12:46 -0500 Subject: [PATCH 91/96] allow mixed embeds/text prompts to completions endpoint Signed-off-by: Andrew Sansom --- .../test_completion_with_prompt_embeds.py | 27 +++++++++ vllm/entrypoints/openai/serving_engine.py | 60 +++++-------------- 2 files changed, 43 insertions(+), 44 deletions(-) diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index b1015cd925a3..341aeb479d99 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -171,6 +171,33 @@ async def test_completions_with_prompt_embeds( assert len(chunks_stream_embeds[0]) > 0 assert len(chunks_stream_embeds[1]) > 0 + # Test case: mixed text and prompt_embeds + encoded_embeds = create_dummy_embeds() + completion_mixed = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="This is a prompt", + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + assert len(completion.choices) == 2 + completion_text_only = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="This is a prompt", + max_tokens=5, + temperature=0.0, + ) + completion_embeds_only = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + # Embeddings responses should be handled first + assert completion_mixed.choices[0].text == completion_embeds_only.choices[ + 0].text + assert completion_mixed.choices[1].text == completion_text_only.choices[ + 0].text + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) diff --git 
a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8249b2971346..b0117d4e2550 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -9,7 +9,7 @@ from concurrent.futures.thread import ThreadPoolExecutor from http import HTTPStatus from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional, - TypeVar, Union, cast, overload) + TypeVar, Union, cast) import torch from fastapi import Request @@ -631,7 +631,7 @@ def _tokenize_prompt_input_or_inputs( list[list[int]]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, - ) -> Union[list[TextTokensPrompt], list[EmbedsPrompt]]: + ) -> list[Union[TextTokensPrompt, EmbedsPrompt]]: """ Tokenize/detokenize depending on the input format. @@ -639,22 +639,24 @@ def _tokenize_prompt_input_or_inputs( , each input can be a string or array of tokens. Note that each request can pass one or more inputs. """ + inputs = list[Union[TextTokensPrompt, EmbedsPrompt]]() + # We want to always ignore text prompts if prompt embeds are available if (isinstance(request, CompletionRequest) and request.prompt_embeds is not None): - return self._load_prompt_embeds(request.prompt_embeds, - truncate_prompt_tokens) + inputs.extend( + self._load_prompt_embeds(request.prompt_embeds, + truncate_prompt_tokens)) - if input_or_inputs is None: - raise ValueError( - "Prompt can only be None for a `/v1/completions` request" - " with prompt_embeds") + # Empty prompts are okay as long as there are prompt embeddings + if input_or_inputs is None or (inputs and input_or_inputs == ""): + return inputs # Although our type checking is based on mypy, # VSCode Pyright extension should still work properly # "is False" is required for Pyright to perform type narrowing # See: https://github.com/microsoft/pyright/issues/7672 - return [ + inputs.extend([ self._normalize_prompt_text_to_input( request, tokenizer, @@ -668,29 +670,9 @@ def _tokenize_prompt_input_or_inputs( prompt_ids=prompt_input["content"], truncate_prompt_tokens=truncate_prompt_tokens) for prompt_input in parse_and_batch_prompt(input_or_inputs) - ] - - @overload - async def _preprocess_completion( - self, - request: CompletionLikeRequest, - tokenizer: AnyTokenizer, - input_or_inputs: Union[str, list[str], list[int], list[list[int]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., - add_special_tokens: bool = ..., - ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]: - ... + ]) - @overload - async def _preprocess_completion( - self, - request: CompletionRequest, - tokenizer: AnyTokenizer, - input_or_inputs: None, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., - add_special_tokens: bool = ..., - ) -> tuple[list[EmbedsPrompt], list[EngineEmbedsPrompt]]: - ... 
+ return inputs async def _preprocess_completion( self, @@ -700,8 +682,8 @@ async def _preprocess_completion( list[list[int]]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, - ) -> tuple[Union[list[TextTokensPrompt], list[EmbedsPrompt]], Union[ - list[EngineTokensPrompt], list[EngineEmbedsPrompt]]]: + ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[ + EngineTokensPrompt, EngineEmbedsPrompt]]]: if not isinstance(request, CompletionRequest) and input_or_inputs is None: raise ValueError( @@ -715,14 +697,6 @@ async def _preprocess_completion( add_special_tokens=add_special_tokens, ) - # This list has a weird type hint. This should be a union of lists, not - # a list of unions, but mypy does not play nicely with the variance of - # typeddicts, and there's no way to easily assert which list type we - # received from request_prompts. So the compromise is we type the list - # as a union as we build it, and then cast it to the correct type when - # we return. - # This casting is safe because we know the list can only be full of - # either EngineTokensPrompt or EngineEmbedsPrompt, but not a mix. engine_prompts: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = [] for request_prompt in request_prompts: @@ -737,9 +711,7 @@ async def _preprocess_completion( else: assert_never(request_prompt) - return request_prompts, cast( - Union[list[EngineTokensPrompt], list[EngineEmbedsPrompt]], - engine_prompts) + return request_prompts, engine_prompts async def _preprocess_chat( self, From 38c366d5227b17ba119d107f7ce10bffb84be727 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 13 May 2025 23:22:01 -0500 Subject: [PATCH 92/96] refactor serving engine to allow mixed embeds/text prompts to completion endpoint while remaining type safe for non-completions endpoints Signed-off-by: Andrew Sansom --- vllm/entrypoints/openai/serving_engine.py | 106 +++++++++++++++------- 1 file changed, 74 insertions(+), 32 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index b0117d4e2550..e4dc278800d9 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -9,13 +9,13 @@ from concurrent.futures.thread import ThreadPoolExecutor from http import HTTPStatus from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional, - TypeVar, Union, cast) + TypeVar, Union, cast, overload) import torch from fastapi import Request from pydantic import BaseModel, ConfigDict, Field from starlette.datastructures import Headers -from typing_extensions import TypeIs, assert_never +from typing_extensions import TypeIs if sys.version_info >= (3, 12): from typing import TypedDict @@ -631,7 +631,7 @@ def _tokenize_prompt_input_or_inputs( list[list[int]]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, - ) -> list[Union[TextTokensPrompt, EmbedsPrompt]]: + ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]: """ Tokenize/detokenize depending on the input format. @@ -639,24 +639,25 @@ def _tokenize_prompt_input_or_inputs( , each input can be a string or array of tokens. Note that each request can pass one or more inputs. 
""" - inputs = list[Union[TextTokensPrompt, EmbedsPrompt]]() + inputs_embeds = list[EmbedsPrompt]() + inputs_text = list[TextTokensPrompt]() - # We want to always ignore text prompts if prompt embeds are available if (isinstance(request, CompletionRequest) and request.prompt_embeds is not None): - inputs.extend( + inputs_embeds.extend( self._load_prompt_embeds(request.prompt_embeds, truncate_prompt_tokens)) # Empty prompts are okay as long as there are prompt embeddings - if input_or_inputs is None or (inputs and input_or_inputs == ""): - return inputs + if input_or_inputs is None or (inputs_embeds + and input_or_inputs == ""): + return [], inputs_embeds # Although our type checking is based on mypy, # VSCode Pyright extension should still work properly # "is False" is required for Pyright to perform type narrowing # See: https://github.com/microsoft/pyright/issues/7672 - inputs.extend([ + inputs_text.extend([ self._normalize_prompt_text_to_input( request, tokenizer, @@ -672,7 +673,33 @@ def _tokenize_prompt_input_or_inputs( for prompt_input in parse_and_batch_prompt(input_or_inputs) ]) - return inputs + return inputs_text, inputs_embeds + + @overload + async def _preprocess_completion( + self, + request: Union[DetokenizeRequest, EmbeddingCompletionRequest, + RerankRequest, ClassificationRequest, ScoreRequest, + TokenizeCompletionRequest], + tokenizer: AnyTokenizer, + input_or_inputs: Union[str, list[str], list[int], list[list[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., + add_special_tokens: bool = ..., + ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]: + ... + + @overload + async def _preprocess_completion( + self, + request: CompletionRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Optional[Union[str, list[str], list[int], + list[list[int]]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., + add_special_tokens: bool = ..., + ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[ + EngineTokensPrompt, EngineEmbedsPrompt]]]: + ... 
async def _preprocess_completion( self, @@ -682,35 +709,50 @@ async def _preprocess_completion( list[list[int]]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, - ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[ - EngineTokensPrompt, EngineEmbedsPrompt]]]: + ) -> tuple[Union[list[TextTokensPrompt], list[Union[ + TextTokensPrompt, EmbedsPrompt]]], Union[ + list[EngineTokensPrompt], list[Union[EngineTokensPrompt, + EngineEmbedsPrompt]]]]: if not isinstance(request, CompletionRequest) and input_or_inputs is None: raise ValueError( "Prompt embeds with non-completion requests is not" " currently supported.") - request_prompts = await self._tokenize_prompt_input_or_inputs_async( - request, - tokenizer, - input_or_inputs, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens, - ) - engine_prompts: list[Union[EngineTokensPrompt, - EngineEmbedsPrompt]] = [] - for request_prompt in request_prompts: - if is_embeds_prompt(request_prompt): - engine_prompts.append( - EngineEmbedsPrompt( - prompt_embeds=request_prompt["prompt_embeds"])) - elif is_text_tokens_prompt(request_prompt): - engine_prompts.append( - EngineTokensPrompt( - prompt_token_ids=request_prompt["prompt_token_ids"])) - else: - assert_never(request_prompt) + (request_prompts_text, request_prompts_embeds + ) = await self._tokenize_prompt_input_or_inputs_async( + request, + tokenizer, + input_or_inputs, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + + engine_prompts_text = [ + EngineTokensPrompt( + prompt_token_ids=request_prompt_text["prompt_token_ids"]) + for request_prompt_text in request_prompts_text + ] + + # This check is equivalent to simply checking if + # `request_prompts_embeds` is empty, but it's difficult to propagate + # overloads to the private helper functions to enable this check. + # This overload is needed because only TextPrompts are allowed for + # non-completion requests and if we don't add the overload here, + # everywhere this function is used outside of serving_completion will + # need logic asserting that only text prompts are in the request. 
+ if not isinstance(request, + CompletionRequest) and input_or_inputs is not None: + return request_prompts_text, engine_prompts_text + + engine_prompts_embeds = [ + EngineEmbedsPrompt( + prompt_embeds=request_prompt_embeds["prompt_embeds"]) + for request_prompt_embeds in request_prompts_embeds + ] + request_prompts = request_prompts_embeds + request_prompts_text + engine_prompts = engine_prompts_embeds + engine_prompts_text return request_prompts, engine_prompts async def _preprocess_chat( self, From a56b7f4cbc8d8e5f0aa494b2c7e381de4931b829 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Tue, 13 May 2025 23:34:31 -0500 Subject: [PATCH 93/96] remove vestigial comments Signed-off-by: Andrew Sansom --- tests/entrypoints/openai/test_completion_with_prompt_embeds.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index 341aeb479d99..27e468227bde 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# imports for guided decoding tests import base64 import io import shutil @@ -19,8 +18,6 @@ # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" -# technically these adapters use a different base model, -# but we're not testing generation quality here LORA_NAME = "typeof/zephyr-7b-beta-lora" CONFIG = AutoConfig.from_pretrained(MODEL_NAME) From 8c1dde9da5d357a3013fdafa560049072341376e Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 16 May 2025 11:11:32 -0500 Subject: [PATCH 94/96] add documentation for serving prompt embeddings Signed-off-by: Andrew Sansom --- docs/source/serving/prompt_embeds.md | 142 +++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 docs/source/serving/prompt_embeds.md diff --git a/docs/source/serving/prompt_embeds.md b/docs/source/serving/prompt_embeds.md new file mode 100644 index 000000000000..483ca16648a4 --- /dev/null +++ b/docs/source/serving/prompt_embeds.md @@ -0,0 +1,142 @@ +# Prompt Embedding Inputs + +This page teaches you how to pass prompt embedding inputs to vLLM. + +## What are prompt embeddings? + +The traditional flow of text data for a Large Language Model goes from text to token ids (via a tokenizer), and then from token ids to prompt embeddings. For a traditional decoder-only model (such as meta-llama/Llama-3.1-8B-Instruct), this step of converting token ids to prompt embeddings happens via a look-up from a learned embedding matrix, but the model is not limited to processing only the embeddings corresponding to its token vocabulary. + +:::{note} +Prompt embeddings are currently only supported in the v0 engine. +::: + +## Offline Inference + +To input prompt embeddings, follow this schema in {class}`vllm.inputs.EmbedsPrompt`: + +- `prompt_embeds`: A torch tensor representing a sequence of prompt/token embeddings. This has the shape (sequence_length, hidden_size), where sequence_length is the number of token embeddings and hidden_size is the hidden size (embedding size) of the model. 
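A minimal sketch of this schema (this sketch is ours, not part of the patch above; it mirrors the dummy-embedding helper used in the tests and reads the hidden size from the Hugging Face config rather than hard-coding it):

```python
import torch
from transformers import AutoConfig
from vllm import LLM

model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Any float tensor of shape (sequence_length, hidden_size) is accepted.
# Random embeddings exercise the input path but naturally produce
# meaningless text; depending on the model's dtype you may also need to
# cast, e.g. dummy_embeds.to(torch.float16).
hidden_size = AutoConfig.from_pretrained(model_name).hidden_size
dummy_embeds = torch.randn(16, hidden_size)  # 16 "token" embeddings

llm = LLM(model=model_name, enable_prompt_embeds=True)
outputs = llm.generate({"prompt_embeds": dummy_embeds})
print(outputs[0].outputs[0].text)
```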
+ +### Hugging Face Transformers Inputs + +You can pass prompt embeddings from Hugging Face Transformers models to the `'prompt_embeds'` field of the prompt embedding dictionary, as shown in the following examples: + +```python +from vllm import LLM +import transformers + +model_name = "meta-llama/Llama-3.2-1B-Instruct" + +# Transformers +tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) +transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name) + +llm = LLM(model=model_name, enable_prompt_embeds=True) + +# Refer to the HuggingFace repo for the correct format to use +chat = [{"role": "user", "content": "Please tell me about the capital of France."}] +token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') + +embedding_layer = transformers_model.get_input_embeddings() +prompt_embeds = embedding_layer(token_ids).squeeze(0) + +# Single prompt inference +outputs = llm.generate({ + "prompt_embeds": prompt_embeds, +}) + +for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + +# Batch inference + +chats = [ + [{"role": "user", "content": "Please tell me about the capital of France."}], + [{"role": "user", "content": "When is the day longest during the year?"}], + [{"role": "user", "content": "Where is bigger, the moon or the sun?"}] +] + +token_ids_list = [ + tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') for chat in chats +] +prompt_embeds_list = [embedding_layer(token_ids).squeeze(0) for token_ids in token_ids_list] + +outputs = llm.generate( + [ + { + "prompt_embeds": prompt_embeds, + } for prompt_embeds in prompt_embeds_list + ] +) + +for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) +``` + +## Online Serving + +Our OpenAI-compatible server accepts prompt embeddings as input via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings are passed via a new `'prompt_embeds'` key in the JSON request body. + +When a mixture of `'prompt_embeds'` and `'prompt'` inputs is provided in a single request, the prompt embeds are always returned first. + +Prompt embeddings are passed in as base64 encoded torch tensors. 
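For reference, the base64 encoding used by the tests and by the client example below can be wrapped in a small helper (a sketch; the helper name is ours, and the server presumably reverses the process by `torch.load`-ing the decoded bytes):

```python
import base64
import io

import torch


def encode_prompt_embeds(prompt_embeds: torch.Tensor) -> str:
    # Serialize the (sequence_length, hidden_size) tensor with torch.save,
    # then base64-encode the raw bytes for the `prompt_embeds` field.
    buffer = io.BytesIO()
    torch.save(prompt_embeds, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
```

When both text and embeddings are sent in one request, choice index 0 of the response corresponds to the embedding prompt, as exercised by the mixed-input test above.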
+ +### Transformers Inputs via OpenAI Client + +First, launch the OpenAI-compatible server: + +```bash +vllm serve meta-llama/Llama-3.2-1B-Instruct --task generate \ + --max-model-len 4096 --enable-prompt-embeds +``` + +Then, you can use the OpenAI client as follows: + +```python +from openai import OpenAI +import transformers +import torch +import base64 +import io + +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) + +model_name = "meta-llama/Llama-3.2-1B-Instruct" + +# Transformers +tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) +transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name) + + +# Refer to the HuggingFace repo for the correct format to use +chat = [{"role": "user", "content": "Please tell me about the capital of France."}] +token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') + +embedding_layer = transformers_model.get_input_embeddings() +prompt_embeds = embedding_layer(token_ids).squeeze(0) + +# Prompt embeddings +buffer = io.BytesIO() +torch.save(prompt_embeds, buffer) +buffer.seek(0) +binary_data = buffer.read() +encoded_embeds = base64.b64encode(binary_data).decode('utf-8') + + +completion = client.completions.create( + model=model_name, + # NOTE: The OpenAI client does not allow `None` as an input to + # `prompt`. Use an empty string if you have no text prompts. + prompt="", + max_tokens=5, + temperature=0.0, + # NOTE: The OpenAI client allows passing in extra JSON body via the + # `extra_body` argument. + extra_body={"prompt_embeds": encoded_embeds} +) + +print(completion.choices[0].text) +``` From 204952cf3f63c5fb0363e6d35f9c850f5d1fd614 Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 16 May 2025 11:12:58 -0500 Subject: [PATCH 95/96] remove explicit dependence on v0 for prompt embeddings test since the engine is chosen implicitly Signed-off-by: Andrew Sansom --- .../entrypoints/openai/test_completion_with_prompt_embeds.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index 27e468227bde..b7ee3e33c2d2 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -70,10 +70,7 @@ def server_with_prompt_embeds(default_server_args, request): if request.param: default_server_args.append(request.param) - # Prompt embeds are currently only supported on v0 - with RemoteOpenAIServer(MODEL_NAME, - default_server_args, - env_dict={"VLLM_USE_V1": "0"}) as remote_server: + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: yield remote_server From 1351bdd398a4c67f4a99d486ecc56f969805fb9d Mon Sep 17 00:00:00 2001 From: Andrew Sansom Date: Fri, 16 May 2025 12:41:26 -0500 Subject: [PATCH 96/96] add prompt embeds docs to toctree Signed-off-by: Andrew Sansom --- docs/source/index.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/index.md b/docs/source/index.md index bbff7361f752..ca7ec264260e 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -118,6 +118,7 @@ training/rlhf.md serving/offline_inference serving/openai_compatible_server serving/multimodal_inputs +serving/prompt_embeds serving/distributed_serving serving/metrics serving/engine_args