From ed8eb22ca398fa6032a012460fbb1fb20d6f349b Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Sat, 27 Jul 2024 16:23:39 -0500 Subject: [PATCH 01/88] feat: add support for generate from prompt embeddings --- tests/conftest.py | 22 +++--- tests/models/test_models.py | 23 ++++++- tests/worker/test_model_runner.py | 84 ++++++++++++++++++----- vllm/engine/async_llm_engine.py | 8 ++- vllm/engine/llm_engine.py | 15 +++- vllm/inputs/__init__.py | 7 +- vllm/inputs/data.py | 25 ++++++- vllm/model_executor/models/bloom.py | 12 +++- vllm/model_executor/models/gemma.py | 13 ++-- vllm/model_executor/models/gpt2.py | 13 +++- vllm/model_executor/models/gpt_bigcode.py | 11 ++- vllm/model_executor/models/gpt_neox.py | 12 +++- vllm/model_executor/models/llama.py | 16 +++-- vllm/model_executor/models/opt.py | 17 ++++- vllm/model_executor/models/phi.py | 11 ++- vllm/model_executor/models/stablelm.py | 12 +++- vllm/model_executor/models/starcoder2.py | 12 +++- vllm/model_executor/models/utils.py | 20 +++++- vllm/outputs.py | 8 +++ vllm/sequence.py | 29 +++++++- vllm/worker/model_runner.py | 73 ++++++++++++++++---- 21 files changed, 362 insertions(+), 81 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 652d627377786..5076ecec44a4d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import os import sys from collections import UserList -from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar +from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union import pytest import torch @@ -18,7 +18,7 @@ from vllm.config import TokenizerPoolConfig from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) -from vllm.inputs import TextPrompt +from vllm.inputs import EmbedsPrompt, TextPrompt from vllm.logger import init_logger from vllm.sequence import SampleLogprobs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, @@ -433,14 +433,18 @@ def __init__( def generate( self, - prompts: List[str], + prompts_or_prompt_embeds: Union[List[str], List[torch.Tensor]], sampling_params: SamplingParams, images: Optional[List[Image.Image]] = None, ) -> List[Tuple[List[List[int]], List[str]]]: if images is not None: - assert len(prompts) == len(images) + assert len(prompts_or_prompt_embeds) == len(images) - inputs = [TextPrompt(prompt=prompt) for prompt in prompts] + inputs = [ + EmbedsPrompt(prompt_embeds=prompt) if isinstance( + prompt, torch.Tensor) else TextPrompt(prompt=prompt) + for prompt in prompts_or_prompt_embeds + ] if images is not None: for i, image in enumerate(images): inputs[i]["multi_modal_data"] = {"image": image} @@ -458,7 +462,7 @@ def generate( output_str = sample.text output_ids = list(sample.token_ids) req_sample_output_ids.append(prompt_ids + output_ids) - req_sample_output_strs.append(prompt_str + output_str) + req_sample_output_strs.append((prompt_str or "") + output_str) outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs @@ -491,12 +495,14 @@ def generate_w_logprobs( def generate_greedy( self, - prompts: List[str], + prompts_or_prompt_embeds: Union[List[str], List[torch.Tensor]], max_tokens: int, images: Optional[List[Image.Image]] = None, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - outputs = self.generate(prompts, greedy_params, images=images) + outputs = self.generate(prompts_or_prompt_embeds, + greedy_params, + images=images) return [(output_ids[0], output_str[0]) for output_ids, output_str in 
outputs] diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 4cd2cb665c8f0..bc59fc4ee0886 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -6,7 +6,6 @@ Run `pytest tests/models/test_models.py`. """ import pytest - from .utils import check_outputs_equal MODELS = [ @@ -39,9 +38,20 @@ def test_models( with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + prompt_embeds = [] + prompt_token_ids = [] + for prompt in example_prompts: + token_ids = hf_model.tokenizer(prompt, + return_tensors="pt").input_ids.to( + hf_model.model.device) + prompt_token_ids.append(token_ids) + prompt_embeds.append( + hf_model.model.get_input_embeddings()(token_ids).squeeze(0)) with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + vllm_outputs_from_embeds = vllm_model.generate_greedy( + prompt_embeds, max_tokens) check_outputs_equal( outputs_0_lst=hf_outputs, @@ -50,6 +60,17 @@ def test_models( name_1="vllm", ) + check_outputs_equal( + outputs_0_lst=vllm_outputs, + outputs_1_lst=[ + (prompt_ids.squeeze().tolist() + output_ids, prompt + output_str) + for (output_ids, output_str), prompt_ids, prompt in zip( + vllm_outputs_from_embeds, prompt_token_ids, example_prompts) + ], + name_0="vllm", + name_1="vllm_from_embeds", + ) + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index b5742c4338616..b8a7d55b74da2 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,7 +1,8 @@ from typing import List - +import itertools import pytest import torch +import random from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, init_distributed_environment) @@ -29,8 +30,10 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: return model_runner -@pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_prompt(batch_size): +@pytest.mark.parametrize("batch_size, prompt_embeds_ratio", + list(itertools.product(range(1, 257), + (0.0, 0.5, 1.0)))) +def test_prepare_prompt(batch_size, prompt_embeds_ratio): model_runner = _create_model_runner( "facebook/opt-125m", max_num_batched_tokens=100000, @@ -41,11 +44,16 @@ def test_prepare_prompt(batch_size): seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] block_tables = {0: [1]} + input_embeds_len = 0 for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(list(range(seq_len))) + if random.random() < prompt_embeds_ratio: + seq_data = SequenceData([], prompt_embeds=torch.rand(seq_len, 10)) + input_embeds_len += seq_len + else: + seq_data = SequenceData(list(range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -66,6 +74,8 @@ def test_prepare_prompt(batch_size): seq_group_metadata_list) input_tokens = model_input.input_tokens input_positions = model_input.input_positions + input_embeds = model_input.input_embeds + input_embeds_masks = model_input.input_embeds_masks attn_metadata = model_input.attn_metadata return_seq_lens = model_input.seq_lens slot_mapping = attn_metadata.slot_mapping @@ -119,7 +129,12 @@ def test_prepare_prompt(batch_size): assert len(input_tokens) == sum(seq_lens) assert 
len(input_positions) == sum(seq_lens) - torch.testing.assert_close(input_tokens, input_positions) + assert len(input_embeds_masks) == sum(seq_lens) + if input_embeds_len == 0: + torch.testing.assert_close(input_tokens, input_positions) + assert input_embeds is None + else: + assert len(input_embeds) == input_embeds_len sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, @@ -144,7 +159,8 @@ def test_prepare_prompt(batch_size): @pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_decode_cuda_graph(batch_size): +@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0)) +def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio): model_runner = _create_model_runner( "facebook/opt-125m", seed=0, @@ -158,11 +174,17 @@ def test_prepare_decode_cuda_graph(batch_size): context_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] # Assume each seq group finishes prefill. + input_embeds_len = 0 for i in range(batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) - seq_data = SequenceData(list(range(context_len))) + if random.random() < prompt_embeds_ratio: + seq_data = SequenceData([], + prompt_embeds=torch.rand(context_len, 10)) + input_embeds_len += context_len + else: + seq_data = SequenceData(list(range(context_len))) seq_data.update_num_computed_tokens(context_len) # Append one token ID since prefill is finished. seq_data.append_token_id(1, 0) @@ -178,9 +200,11 @@ def test_prepare_decode_cuda_graph(batch_size): model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, attn_metadata, slot_mapping = ( + input_tokens, input_positions, input_embeds, input_embeds_masks, attn_metadata, slot_mapping = ( model_input.input_tokens, model_input.input_positions, + model_input.input_embeds, model_input.input_embeds_masks, model_input.attn_metadata, model_input.attn_metadata.slot_mapping) + assert len(slot_mapping) == len(input_tokens) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) @@ -232,6 +256,8 @@ def test_prepare_decode_cuda_graph(batch_size): assert len(input_tokens) == expected_bs assert len(input_positions) == expected_bs torch.allclose(input_tokens, input_positions) + assert input_embeds is None + assert input_embeds_masks is None # Verify Sampling expected_selected_token_indices = [] @@ -264,14 +290,18 @@ def test_empty_seq_group(): seq_group_metadata_list: List[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, attn_metadata = ( + input_tokens, input_positions, input_embeds, input_embeds_masks, attn_metadata = ( model_input.input_tokens, model_input.input_positions, + model_input.input_embeds, + model_input.input_embeds_masks, model_input.attn_metadata, ) assert input_tokens is None assert input_positions is None assert attn_metadata is None + assert input_embeds is None + assert input_embeds_masks is None model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) @@ -299,7 +329,9 @@ def distributed_init(): @pytest.mark.parametrize("batch_size", list(range(2, 128))) @pytest.mark.parametrize("enforce_eager", [True, False]) -def test_hybrid_batches(batch_size, enforce_eager, distributed_init): +@pytest.mark.parametrize('prompt_embeds_ratio', [0.0, 0.5, 1.0]) +def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, + 
distributed_init): model_runner = _create_model_runner( "facebook/opt-125m", seed=0, @@ -318,11 +350,16 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): block_tables = {0: [1]} prefill_batch_size = batch_size // 2 decode_batch_size = batch_size - prefill_batch_size + input_embeds_len = 0 for i in range(prefill_batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(list(range(seq_len))) + if random.random() < prompt_embeds_ratio: + seq_data = SequenceData([], prompt_embeds=torch.rand(seq_len, 10)) + input_embeds_len += seq_len + else: + seq_data = SequenceData(list(range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -338,8 +375,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 - prompt_toks = list(range(context_len)) - seq_data = SequenceData(prompt_toks) + if random.random() < prompt_embeds_ratio: + seq_data = SequenceData([], + prompt_embeds=torch.rand(context_len, 10)) + else: + seq_data = SequenceData(list(range(context_len))) seq_data.append_token_id(1, 0) seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( @@ -354,11 +394,14 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): decode_metadata_list.append(seq_group_metadata) model_input = model_runner.prepare_model_input(seq_group_metadata_list) - (input_tokens, input_positions, attn_metadata) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - ) + (input_tokens, input_positions, attn_metadata, input_embeds, + input_embeds_masks) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.input_embeds, + model_input.input_embeds_masks, + ) prefill_meta_actual = attn_metadata.prefill_metadata decode_meta_actual = attn_metadata.decode_metadata @@ -368,6 +411,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert attn_metadata.num_prefills == prefill_batch_size assert attn_metadata.num_decode_tokens == decode_batch_size assert attn_metadata.num_prefill_tokens == sum(seq_lens) + assert len(input_embeds_masks) == sum(seq_lens) + if input_embeds_len == 0: + assert input_embeds is None + else: + assert len(input_embeds) == input_embeds_len # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. 
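A minimal end-to-end sketch of the prompt-embeddings path that the tests above exercise, assuming facebook/opt-125m (one of the models patched in this series); the prompt text, variable names, and the HF-side embedding lookup are illustrative only and not part of this patch:

from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.inputs import EmbedsPrompt

model_name = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)

# Build prompt embeddings offline from the HF embedding table, mirroring
# tests/models/test_models.py above: (1, seq_len, hidden) -> (seq_len, hidden).
token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

llm = LLM(model=model_name)
outputs = llm.generate(EmbedsPrompt(prompt_embeds=prompt_embeds),
                       sampling_params=SamplingParams(temperature=0.0,
                                                      max_tokens=16))
print(outputs[0].prompt_embeds_shape, outputs[0].outputs[0].text)

Token prompts and EmbedsPrompt inputs can also be scheduled in the same batch, which is what the prompt_embeds_ratio parametrization in the model-runner tests above checks.
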
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 16b7bc64a2849..8add0f4dd16d1 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -281,7 +281,10 @@ async def process_model_inputs_async( if isinstance(inputs, str): inputs = {"prompt": inputs} - if "prompt_token_ids" not in inputs: + prompt_embeds = inputs.get("prompt_embeds") + if prompt_embeds is not None: + prompt_token_ids = [] + elif "prompt_token_ids" not in inputs: tokenizer = self.get_tokenizer_group("prompts must be None if " "skip_tokenizer_init is True") @@ -299,6 +302,7 @@ async def process_model_inputs_async( prompt_token_ids llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, + prompt_embeds=prompt_embeds, prompt=inputs.get("prompt"), multi_modal_data=inputs.get("multi_modal_data")) @@ -719,7 +723,7 @@ async def generate( request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. trace_headers: OpenTelemetry trace headers. - prompt_adapter_request: Prompt Adapter request to use + prompt_adapter_request: Prompt Adapter request to use for generation, if any. Yields: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index eabe3b23a9d58..1276adb33eb9a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -91,13 +91,13 @@ class LLMEngine: scheduler_config: The configuration related to the request scheduler. device_config: The configuration related to the device. lora_config (Optional): The configuration related to serving multi-LoRA. - multimodal_config (Optional): The configuration related to multimodal + multimodal_config (Optional): The configuration related to multimodal models. speculative_config (Optional): The configuration related to speculative decoding. executor_class: The model executor class for managing distributed execution. - prompt_adapter_config (Optional): The configuration related to serving + prompt_adapter_config (Optional): The configuration related to serving prompt adapters. log_stats: Whether to log statistics. usage_context: Specified entry point, used for usage info collection. 
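The engine-side routing below and the per-model forward() changes that follow all rely on the get_inputs_embeds helper added to vllm/model_executor/models/utils.py later in this patch. A standalone sketch of its merge behaviour for a mixed prefill batch, with hypothetical shapes that are not taken from this patch:

import torch

# Flattened prefill batch of 5 positions: the first two come from token ids,
# the last three from caller-provided embeddings.
hidden_size = 8
embed_tokens = torch.nn.Embedding(100, hidden_size)

input_ids = torch.tensor([5, 7, 0, 0, 0])             # 0s are placeholders for embed-backed positions
inputs_embeds = torch.rand(3, hidden_size)            # one row per masked position, in order
inputs_embeds_masks = torch.tensor([False, False, True, True, True])

# Mirrors get_inputs_embeds(): embed every position from its token id, then
# overwrite the masked positions with the precomputed embeddings.
hidden_states = embed_tokens(input_ids)
hidden_states[inputs_embeds_masks] = inputs_embeds

When every position in the batch is embed-backed, the helper skips the embedding lookup entirely and returns inputs_embeds as-is.
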
@@ -576,7 +576,15 @@ def process_model_inputs( if isinstance(inputs, str): inputs = {"prompt": inputs} - if "prompt_token_ids" not in inputs: + prompt_embeds = inputs.get("prompt_embeds") + if prompt_embeds is not None: + if not self.model_executor.driver_worker.model_runner.model_supports_input_embeds: + raise ValueError( + f"Model {self.model_config.model} does not support input embeddings, but prompt_embeds " + "was provided.") + prompt_token_ids = [] + + elif "prompt_token_ids" not in inputs: tokenizer = self.get_tokenizer_group("prompts must be None if " "skip_tokenizer_init is True") @@ -592,6 +600,7 @@ def process_model_inputs( + prompt_token_ids llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, + prompt_embeds=prompt_embeds, prompt=inputs.get("prompt"), multi_modal_data=inputs.get("multi_modal_data")) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index b13d9acf93d3b..d297593d65f6a 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,5 +1,6 @@ from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, - TextPrompt, TokensPrompt, parse_and_batch_prompt) + TextPrompt, TokensPrompt, EmbedsPrompt, + parse_and_batch_prompt) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -13,6 +14,6 @@ __all__ = [ "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", - "TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY", - "InputContext", "InputRegistry" + "TokensPrompt", "EmbedsPrompt", "PromptInputs", "LLMInputs", + "INPUT_REGISTRY", "InputContext", "InputRegistry" ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 4443e6c70fe5b..25461b3e4e272 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -3,6 +3,8 @@ from typing_extensions import NotRequired +import torch + if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict @@ -33,6 +35,9 @@ def parse_and_batch_prompt( def parse_and_batch_prompt( prompt: Union[str, List[str], List[int], List[List[int]]], ) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: + if len(prompt) == 0: + return [] + if isinstance(prompt, str): # case 1: a string return [ParsedText(content=prompt, is_tokens=False)] @@ -92,7 +97,20 @@ class TokensPrompt(TypedDict): """ -PromptInputs = Union[str, TextPrompt, TokensPrompt] +class EmbedsPrompt(TypedDict): + """Schema for a tokenized prompt.""" + + prompt_embeds: torch.Tensor + """Embeddings of the prompt to pass to the model.""" + + multi_modal_data: NotRequired["MultiModalDataDict"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +PromptInputs = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] """ The inputs to the LLM, which can take one of the following forms: @@ -114,6 +132,11 @@ class LLMInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ + prompt_embeds: NotRequired[Optional[torch.Tensor]] + """ + The embeddings of the prompt, if available. 
+ """ + multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] """ Optional multi-modal data to pass to the model, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 86ae32e0cb01f..cd6604ac89025 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -41,6 +41,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +from .utils import get_inputs_embeds + def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) @@ -249,8 +251,11 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.word_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.word_embeddings, + inputs_embeds, inputs_embeds_masks) hidden_states = self.word_embeddings_layernorm(hidden_states) for i in range(len(self.h)): layer = self.h[i] @@ -287,9 +292,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 7e0888b5f5abd..c49b0e014f826 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -42,6 +42,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA +from .utils import get_inputs_embeds logger = init_logger(__name__) @@ -279,11 +280,10 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.get_input_embeddings, + inputs_embeds, inputs_embeds_masks) hidden_states *= self.normalizer residual = None for i in range(len(self.layers)): @@ -347,9 +347,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 94cd67e75336a..167f42f878092 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -41,7 +41,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from .utils import is_pp_missing_parameter, make_layers +from .utils import 
is_pp_missing_parameter, make_layers, get_inputs_embeds class GPT2Attention(nn.Module): @@ -212,9 +212,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - inputs_embeds = self.wte(input_ids) + inputs_embeds = get_inputs_embeds(input_ids, self.wte, + inputs_embeds, + inputs_embeds_masks) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds else: @@ -260,9 +264,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index fc4e13bbb0e68..254024744b980 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -42,6 +42,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA +from .utils import get_inputs_embeds class GPTBigCodeAttention(nn.Module): @@ -219,8 +220,11 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - inputs_embeds = self.wte(input_ids) + inputs_embeds = get_inputs_embeds(inputs_embeds, self.wte, + inputs_embeds, inputs_embeds_masks) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds @@ -274,9 +278,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index b306574b2ed92..f2d36977a2d43 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -40,6 +40,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +from .utils import get_inputs_embeds + class GPTNeoXAttention(nn.Module): @@ -212,8 +214,11 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_in(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_in, + inputs_embeds, inputs_embeds_masks) for i in range(len(self.layers)): layer = self.layers[i] hidden_states = layer( @@ -253,9 +258,12 @@ def forward( kv_caches: 
List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.gpt_neox(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d052113e79892..17f058bc8885a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -50,7 +50,7 @@ from vllm.utils import is_hip from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers, get_inputs_embeds class LlamaMLP(nn.Module): @@ -303,12 +303,13 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, + self.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -416,9 +417,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return model_output def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index edc16710c0229..bdf92acbff54e 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -41,6 +41,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +from .utils import get_inputs_embeds + class OPTLearnedPositionalEmbedding(nn.Embedding): @@ -243,8 +245,11 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) pos_embeds = self.embed_positions(positions) if self.project_in is not None: inputs_embeds, _ = self.project_in(inputs_embeds) @@ -278,8 +283,11 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self.decoder(input_ids, positions, kv_caches, attn_metadata) + return self.decoder(input_ids, positions, kv_caches, attn_metadata, + inputs_embeds, inputs_embeds_masks) class OPTForCausalLM(nn.Module): @@ -305,9 +313,12 @@ def forward( 
kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index ac7496f68fd99..1668de1adb10b 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -60,6 +60,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA +from .utils import get_inputs_embeds class PhiAttention(nn.Module): @@ -215,8 +216,11 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) for i in range(self.config.num_hidden_layers): layer = self.layers[i] hidden_states = layer( @@ -280,9 +284,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 5451b56ed05f7..4a366b67d7c5b 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -43,6 +43,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +from .utils import get_inputs_embeds + class StablelmMLP(nn.Module): @@ -214,8 +216,11 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer( @@ -253,9 +258,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 1752bfd473b88..cd87cfc09ce80 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -42,6 +42,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, 
SamplerOutput +from .utils import get_inputs_embeds + class Starcoder2Attention(nn.Module): @@ -218,8 +220,11 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) for i in range(len(self.layers)): layer = self.layers[i] hidden_states = layer(positions, hidden_states, kv_caches[i], @@ -263,9 +268,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 197d3839a766a..401f013845f2d 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Protocol, Tuple +from typing import Dict, List, Protocol, Tuple, Optional import torch from torch.func import functional_call @@ -176,3 +176,21 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: if name.startswith(missing_layer_name): return True return False + + +def get_inputs_embeds( + input_ids: torch.Tensor, + embeddings: torch.nn.Module, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Get the input embeddings from either `input_ids` and `inputs_embeds`.""" + if inputs_embeds is not None: + if all(inputs_embeds_masks): + hidden_states = inputs_embeds + else: + hidden_states = embeddings(input_ids) + hidden_states[inputs_embeds_masks] = inputs_embeds + else: + hidden_states = embeddings(input_ids) + return hidden_states diff --git a/vllm/outputs.py b/vllm/outputs.py index 4cb7f06bdb8c7..b9b50e82ccafd 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -83,6 +83,7 @@ def __init__( request_id: str, prompt: Optional[str], prompt_token_ids: List[int], + prompt_embeds_shape: Optional[Tuple[int, int]], prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, @@ -92,6 +93,7 @@ def __init__( self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids + self.prompt_embeds_shape = prompt_embeds_shape self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished @@ -136,6 +138,10 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # Every sequence in the sequence group should have the same prompt. 
prompt = seq_group.prompt prompt_token_ids = seq_group.prompt_token_ids + if (prompt_embeds := seq_group.prompt_embeds) is not None: + prompt_embeds_shape = tuple(prompt_embeds.shape) + else: + prompt_embeds_shape = None prompt_logprobs = seq_group.prompt_logprobs finished = seq_group.is_finished() finished_time = time.time() if finished else None @@ -143,6 +149,7 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": return cls(seq_group.request_id, prompt, prompt_token_ids, + prompt_embeds_shape, prompt_logprobs, outputs, finished, @@ -153,6 +160,7 @@ def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"prompt_embeds_shape={self.prompt_embeds_shape}, " f"prompt_logprobs={self.prompt_logprobs}, " f"outputs={self.outputs}, " f"finished={self.finished}, " diff --git a/vllm/sequence.py b/vllm/sequence.py index 0cd4c7e71d78d..6df64ab4b635d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -117,9 +117,11 @@ class SequenceData: def __init__( self, prompt_token_ids: List[int], + prompt_embeds: Optional[torch.Tensor] = None, output_token_ids: Optional[List[int]] = None, ) -> None: self._prompt_token_ids: List[int] = list(prompt_token_ids) + self._prompt_embeds: Optional[List[torch.Tensor]] = prompt_embeds self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids) self._output_token_ids: List[int] = ( list(output_token_ids) if output_token_ids is not None else []) @@ -145,6 +147,14 @@ def prompt_token_ids(self, new_prompt_token_ids) -> None: self._prompt_token_ids_tuple = tuple(new_prompt_token_ids) self._update_cached_all_tokens() + @property + def prompt_embeds(self) -> Optional[torch.Tensor]: + return self._prompt_embeds + + @prompt_embeds.setter + def prompt_embeds(self, new_prompt_embeds: Optional[torch.Tensor]) -> None: + self._prompt_embeds = new_prompt_embeds + @property def output_token_ids(self) -> Tuple[int, ...]: return tuple(self._output_token_ids) @@ -160,7 +170,10 @@ def append_token_id(self, token_id: int, logprob: float) -> None: self.cumulative_logprob += logprob def get_len(self) -> int: - return len(self._output_token_ids) + len(self._prompt_token_ids) + if self._prompt_embeds is None: + return len(self._output_token_ids) + len(self._prompt_token_ids) + else: + return len(self._output_token_ids) + len(self._prompt_embeds) def get_prompt_len(self) -> int: return len(self._prompt_token_ids) @@ -261,7 +274,7 @@ def __init__( self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - self.data = SequenceData(self.prompt_token_ids) + self.data = SequenceData(self.prompt_token_ids, self.prompt_embeds) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -286,6 +299,10 @@ def prompt(self) -> Optional[str]: def prompt_token_ids(self) -> List[int]: return self.inputs["prompt_token_ids"] + @property + def prompt_embeds(self) -> Optional[torch.Tensor]: + return self.inputs.get("prompt_embeds") + @property def multi_modal_data(self) -> "MultiModalDataDict": return self.inputs.get("multi_modal_data") or {} @@ -472,6 +489,12 @@ def prompt_token_ids(self) -> List[int]: # We use the prompt of an arbitrary sequence. return self._first_seq.prompt_token_ids + @property + def prompt_embeds(self) -> Optional[torch.Tensor]: + # All sequences in the group should have the same prompt. + # We use the prompt of an arbitrary sequence. 
+ return self._first_seq.prompt_embeds + @property def multi_modal_data(self) -> "MultiModalDataDict": # All sequences in the group should have the same multi-modal data. @@ -642,7 +665,7 @@ class SequenceGroupMetadata: state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. encoder_seq_data: Optional sequence data for encoder prompt - (SequenceGroup.encoder_seq). Should be None + (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder model. cross_block_table: Optional cross-attention block table associated diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e63be184af16a..3f6afcd5a2b72 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, Tuple, Type, TypeVar, Union) +import inspect import numpy as np import torch @@ -86,6 +87,8 @@ class ModelInputForGPU(ModelRunnerInputBase): additional fields. """ input_tokens: Optional[torch.Tensor] = None + input_embeds: Optional[torch.Tensor] = None + input_embeds_masks: Optional[torch.BoolTensor] = None input_positions: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None @@ -102,6 +105,8 @@ class ModelInputForGPU(ModelRunnerInputBase): def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, + "input_embeds": self.input_embeds, + "input_embeds_masks": self.input_embeds_masks, "input_positions": self.input_positions, "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, @@ -186,6 +191,10 @@ class InterDataForSeqGroup: input_tokens: List[List[int]] = field(default_factory=list) input_positions: List[List[int]] = field(default_factory=list) + # Input embeddings and masks. + input_embeds: Optional[torch.Tensor] = None + input_embeds_mask: Optional[torch.BoolTensor] = None + # The sequence length (may be capped to the sliding window). seq_lens: List[int] = field(default_factory=list) # The original sequence length (before applying sliding window). @@ -299,9 +308,19 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, context_len = seq_len - 1 seq_len = min(seq_len, context_len + token_chunk_size) + input_embeds = None + input_embeds_mask = None # Compute tokens. if inter_data.is_prompt: - tokens = seq_data.get_token_ids()[context_len:seq_len] + if seq_data.prompt_embeds is None: + tokens = seq_data.get_token_ids()[context_len:seq_len] + input_embeds_mask = torch.zeros(seq_len - context_len, + dtype=torch.bool) + else: + tokens = [0] * seq_len + input_embeds = seq_data.prompt_embeds[context_len:seq_len] + input_embeds_mask = torch.ones(seq_len - context_len, + dtype=torch.bool) else: # Optimization. get_token_ids requires the entire copy of # tokens. @@ -314,6 +333,8 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, inter_data.input_positions[seq_idx] = list(range(context_len, seq_len)) inter_data.query_lens[ seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 + inter_data.input_embeds = input_embeds + inter_data.input_embeds_mask = input_embeds_mask def _compute_for_prefix_cache_hit( self, inter_data: InterDataForSeqGroup, seq_idx: int, @@ -335,8 +356,8 @@ def _compute_for_prefix_cache_hit( "chunked prefill cannot be used with prefix caching now.") # If prefix cache is hit, advance context length to bypass - # hit blocks. 
Accordingly, input tokens, position and query length - # have to be updated. + # hit blocks. Accordingly, input tokens, position, query length + # input_embeds, and input_embeds_mask have to be updated. if prefix_cache_hit: assert computed_block_nums is not None context_len = len(computed_block_nums) * self.block_size @@ -347,6 +368,10 @@ def _compute_for_prefix_cache_hit( inter_data.context_lens[seq_idx] = context_len inter_data.query_lens[ seq_idx] = inter_data.seq_lens[seq_idx] - context_len + if inter_data.input_embeds is not None: + inter_data.input_embeds = inter_data.input_embeds[context_len:] + inter_data.input_embeds_mask = inter_data.input_embeds_mask[ + context_len:] def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup, seq_idx: int, @@ -474,6 +499,21 @@ def build(self) -> ModelInputForGPU: flatten_2d_lists(inter_data.input_positions) for inter_data in self.inter_data_list ]) + input_embeds = [ + inter_data.input_embeds for inter_data in self.inter_data_list + if inter_data.input_embeds is not None + ] or None + input_embeds_masks = [ + inter_data.input_embeds_mask for inter_data in self.inter_data_list + if inter_data.input_embeds_mask is not None + ] or None + if input_embeds: + input_embeds = torch.cat(input_embeds).to( + device=self.runner.device, + dtype=self.runner.model_config.dtype) + if input_embeds_masks: + input_embeds_masks = torch.cat(input_embeds_masks).to( + self.runner.device) seq_lens = [] max_decode_seq_len = 0 for inter_data in self.inter_data_list: @@ -575,6 +615,8 @@ def build(self) -> ModelInputForGPU: return self.model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, + input_embeds=input_embeds, + input_embeds_masks=input_embeds_masks, attn_metadata=attn_metadata, seq_lens=seq_lens, query_lens=query_lens, @@ -664,6 +706,7 @@ def __init__( # Lazy initialization self.model: nn.Module # Set after load_model + self.model_supports_input_embeds = False # Set after load_model # Set after load_model. 
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None @@ -687,7 +730,9 @@ def load_model(self) -> None: parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, cache_config=self.cache_config) - + model_forward_params = inspect.signature(self.model.forward).parameters + if "inputs_embeds" in model_forward_params and "inputs_embeds_masks" in model_forward_params: + self.model_supports_input_embeds = True self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) @@ -1311,14 +1356,18 @@ def execute_model( "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_seqlen_agnostic else {} - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **multi_modal_kwargs, - **seqlen_agnostic_kwargs) + model_params = dict(input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **multi_modal_kwargs, + **seqlen_agnostic_kwargs) + if self.model_supports_input_embeds: + model_params.update( + inputs_embeds=model_input.input_embeds, + inputs_embeds_masks=model_input.input_embeds_masks) + hidden_or_intermediate_states = model_executable(**model_params) # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: From 3bd64235d124dbca70d8fa00c245fb9ccb5ac137 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Sat, 27 Jul 2024 18:04:52 -0500 Subject: [PATCH 02/88] fix ci errors --- tests/models/test_models.py | 1 + tests/worker/test_model_runner.py | 22 ++++++++++++++++++---- vllm/engine/llm_engine.py | 7 ++++--- vllm/inputs/__init__.py | 4 ++-- vllm/inputs/data.py | 3 +-- vllm/model_executor/models/gpt2.py | 2 +- vllm/model_executor/models/llama.py | 3 ++- vllm/model_executor/models/utils.py | 2 +- vllm/worker/model_runner.py | 8 ++++++-- 9 files changed, 36 insertions(+), 16 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index bc59fc4ee0886..0a2d386d7122f 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -6,6 +6,7 @@ Run `pytest tests/models/test_models.py`. 
""" import pytest + from .utils import check_outputs_equal MODELS = [ diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 4b551e2e6e4c8..00b69dbda4904 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,8 +1,9 @@ -from typing import List import itertools +import random +from typing import List + import pytest import torch -import random from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, init_distributed_environment) @@ -200,7 +201,14 @@ def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio): model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, input_embeds, input_embeds_masks, attn_metadata, slot_mapping = ( + ( + input_tokens, + input_positions, + input_embeds, + input_embeds_masks, + attn_metadata, + slot_mapping + ) = ( model_input.input_tokens, model_input.input_positions, model_input.input_embeds, model_input.input_embeds_masks, model_input.attn_metadata, model_input.attn_metadata.slot_mapping) @@ -291,7 +299,13 @@ def test_empty_seq_group(): seq_group_metadata_list: List[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, input_embeds, input_embeds_masks, attn_metadata = ( + ( + input_tokens, + input_positions, + input_embeds, + input_embeds_masks, + attn_metadata + ) = ( model_input.input_tokens, model_input.input_positions, model_input.input_embeds, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 198c62a2fd1ed..9a23505dba1ad 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -583,11 +583,12 @@ def process_model_inputs( inputs = {"prompt": inputs} prompt_embeds = inputs.get("prompt_embeds") + model_runner = self.model_executor.driver_worker.model_runner if prompt_embeds is not None: - if not self.model_executor.driver_worker.model_runner.model_supports_input_embeds: + if not model_runner.model_supports_input_embeds: raise ValueError( - f"Model {self.model_config.model} does not support input embeddings, but prompt_embeds " - "was provided.") + f"Model {self.model_config.model} does not support input " + "embeddings, but prompt_embeds was provided.") prompt_token_ids = [] elif "prompt_token_ids" not in inputs: diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index d297593d65f6a..9e0b6ef29a772 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,5 +1,5 @@ -from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, - TextPrompt, TokensPrompt, EmbedsPrompt, +from .data import (EmbedsPrompt, LLMInputs, ParsedText, ParsedTokens, + PromptInputs, TextPrompt, TokensPrompt, parse_and_batch_prompt) from .registry import InputContext, InputRegistry diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 25461b3e4e272..decded4ddfcfa 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,9 +1,8 @@ from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, TypedDict, Union, cast, overload) -from typing_extensions import NotRequired - import torch +from typing_extensions import NotRequired if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 167f42f878092..1d26c91c1ea03 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -41,7 +41,7 @@ from 
vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from .utils import is_pp_missing_parameter, make_layers, get_inputs_embeds +from .utils import get_inputs_embeds, is_pp_missing_parameter, make_layers class GPT2Attention(nn.Module): diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index bd9a06ec7cee2..4db718ef1a4e6 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -52,7 +52,8 @@ from vllm.utils import is_hip from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers, get_inputs_embeds +from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter, + make_layers) class LlamaMLP(nn.Module): diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 401f013845f2d..34b1cebad9c7a 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Protocol, Tuple, Optional +from typing import Dict, List, Optional, Protocol, Tuple import torch from torch.func import functional_call diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b617b18add449..95a0c6a1537bb 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,12 +1,12 @@ import dataclasses import gc +import inspect import time import warnings import weakref from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, Tuple, Type, TypeVar, Union) -import inspect import numpy as np import torch @@ -407,6 +407,7 @@ def _compute_for_prefix_cache_hit( seq_idx] = inter_data.seq_lens[seq_idx] - context_len if inter_data.input_embeds is not None: inter_data.input_embeds = inter_data.input_embeds[context_len:] + if inter_data.input_embeds_mask is not None: inter_data.input_embeds_mask = inter_data.input_embeds_mask[ context_len:] @@ -773,7 +774,10 @@ def load_model(self) -> None: scheduler_config=self.scheduler_config, cache_config=self.cache_config) model_forward_params = inspect.signature(self.model.forward).parameters - if "inputs_embeds" in model_forward_params and "inputs_embeds_masks" in model_forward_params: + if ( + "inputs_embeds" in model_forward_params + and "inputs_embeds_masks" in model_forward_params + ): self.model_supports_input_embeds = True self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", From 9a9d406facc4487e27b94a81bce5fd4ee316120e Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Mon, 5 Aug 2024 22:01:15 -0500 Subject: [PATCH 03/88] fix: tensor parallel --- vllm/entrypoints/llm.py | 12 ++++++------ vllm/sequence.py | 2 +- vllm/worker/model_runner.py | 2 ++ 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 262cba79e5712..3cf7f56beb7a3 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -87,7 +87,7 @@ class LLM: disable_custom_all_reduce: See ParallelConfig **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See :ref:`engine_args`) - + Note: This class is intended to be used for offline inference. For online serving, use the :class:`~vllm.AsyncLLMEngine` class instead. @@ -277,13 +277,13 @@ def generate( Args: inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If - None, we use the default sampling parameters. 
- When it is a single value, it is applied to every prompt. - When it is a list, the list must have the same length as the + None, we use the default sampling parameters. + When it is a single value, it is applied to every prompt. + When it is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - prompt_adapter_request: Prompt Adapter request to use for + prompt_adapter_request: Prompt Adapter request to use for generation, if any. Returns: @@ -434,7 +434,7 @@ def encode( use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - prompt_adapter_request: Prompt Adapter request to use for + prompt_adapter_request: Prompt Adapter request to use for generation, if any. Returns: diff --git a/vllm/sequence.py b/vllm/sequence.py index 4401156262812..58e393e87a286 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -492,7 +492,7 @@ def prompt_token_ids(self) -> List[int]: def prompt_embeds(self) -> Optional[torch.Tensor]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. - return self._first_seq.prompt_embeds + return self.seqs[0].prompt_embeds @property def multi_modal_data(self) -> "MultiModalDataDict": diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d1bfe1500a112..9cda876a51534 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -147,6 +147,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "input_embeds": self.input_embeds, + "input_embeds_masks": self.input_embeds_masks, "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, From 48fc6a8c4063f2da8088a8f52815337e70a8d5f8 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Mon, 5 Aug 2024 22:17:39 -0500 Subject: [PATCH 04/88] style: yapf --- tests/worker/test_model_runner.py | 38 ++++++++++++------------------- vllm/worker/model_runner.py | 6 ++--- 2 files changed, 16 insertions(+), 28 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 00b69dbda4904..ea181357e6b92 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -201,17 +201,12 @@ def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio): model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - ( - input_tokens, - input_positions, - input_embeds, - input_embeds_masks, - attn_metadata, - slot_mapping - ) = ( - model_input.input_tokens, model_input.input_positions, - model_input.input_embeds, model_input.input_embeds_masks, - model_input.attn_metadata, model_input.attn_metadata.slot_mapping) + (input_tokens, input_positions, input_embeds, input_embeds_masks, + attn_metadata, + slot_mapping) = (model_input.input_tokens, model_input.input_positions, + model_input.input_embeds, model_input.input_embeds_masks, + model_input.attn_metadata, + model_input.attn_metadata.slot_mapping) assert len(slot_mapping) == len(input_tokens) @@ -299,19 +294,14 @@ def test_empty_seq_group(): seq_group_metadata_list: List[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - ( - input_tokens, - 
input_positions, - input_embeds, - input_embeds_masks, - attn_metadata - ) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.input_embeds, - model_input.input_embeds_masks, - model_input.attn_metadata, - ) + (input_tokens, input_positions, input_embeds, input_embeds_masks, + attn_metadata) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.input_embeds, + model_input.input_embeds_masks, + model_input.attn_metadata, + ) assert input_tokens is None assert input_positions is None assert attn_metadata is None diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9cda876a51534..507cf541a0229 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -776,10 +776,8 @@ def load_model(self) -> None: scheduler_config=self.scheduler_config, cache_config=self.cache_config) model_forward_params = inspect.signature(self.model.forward).parameters - if ( - "inputs_embeds" in model_forward_params - and "inputs_embeds_masks" in model_forward_params - ): + if ("inputs_embeds" in model_forward_params + and "inputs_embeds_masks" in model_forward_params): self.model_supports_input_embeds = True self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", From 737d01bc75eef7beeeda8aa7981f179d5def933d Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Tue, 6 Aug 2024 16:09:38 -0500 Subject: [PATCH 05/88] fix: model_runner in a WorkerWrapper --- vllm/engine/llm_engine.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 789a5205e0437..21302097d9c72 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -44,6 +44,7 @@ usage_message) from vllm.utils import Counter from vllm.version import __version__ as VLLM_VERSION +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -573,7 +574,9 @@ def process_model_inputs( inputs = {"prompt": inputs} prompt_embeds = inputs.get("prompt_embeds") - model_runner = self.model_executor.driver_worker.model_runner + driver_worker = self.model_executor.driver_worker + model_runner = driver_worker.worker.model_runner if isinstance( + driver_worker, WorkerWrapperBase) else driver_worker.model_runner if prompt_embeds is not None: if not model_runner.model_supports_input_embeds: raise ValueError( From 03344ab2d28ec994ba778c6a80c1be5f369e96e4 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Wed, 7 Aug 2024 15:50:48 -0500 Subject: [PATCH 06/88] fix: spec decoding model --- vllm/engine/llm_engine.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 28e4e5b42aa19..109b7d1aded82 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -927,9 +927,12 @@ def _process_decoder_only_prompt( prompt = inputs.get("prompt") prompt_embeds = inputs.get("prompt_embeds") driver_worker = self.model_executor.driver_worker - model_runner = driver_worker.worker.model_runner if isinstance( - driver_worker, WorkerWrapperBase) else driver_worker.model_runner if prompt_embeds is not None: + if self.speculative_config is not None: + raise ValueError( + "Speculative decoding does not support prompt_embeds.") + model_runner = driver_worker.worker.model_runner if isinstance( + driver_worker, WorkerWrapperBase) else driver_worker.model_runner if not model_runner.model_supports_input_embeds: raise ValueError( f"Model {self.model_config.model} 
does not support input " From 40038e0c7c3873af5dc54c035613df1d3044e223 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Wed, 7 Aug 2024 15:52:35 -0500 Subject: [PATCH 07/88] fix: ruff --- vllm/engine/llm_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 109b7d1aded82..7b679f1af7308 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -932,7 +932,8 @@ def _process_decoder_only_prompt( raise ValueError( "Speculative decoding does not support prompt_embeds.") model_runner = driver_worker.worker.model_runner if isinstance( - driver_worker, WorkerWrapperBase) else driver_worker.model_runner + driver_worker, + WorkerWrapperBase) else driver_worker.model_runner if not model_runner.model_supports_input_embeds: raise ValueError( f"Model {self.model_config.model} does not support input " From b00bbc736e617cf3406f39315676207a11392d0c Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Thu, 8 Aug 2024 11:21:17 -0500 Subject: [PATCH 08/88] fix: move param prompt_embeds_shape to the last of RequestOutput --- vllm/outputs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index bad144c7a2ffd..39158a0d77b41 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -84,6 +84,7 @@ class RequestOutput: None if decoder-only encoder_prompt_token_ids: The token IDs of the encoder prompt; None if decoder-only + prompt_embeds_shape: The shape of the prompt embeddings. """ def __init__( @@ -91,7 +92,6 @@ def __init__( request_id: str, prompt: Optional[str], prompt_token_ids: List[int], - prompt_embeds_shape: Optional[Tuple[int, int]], prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, @@ -99,6 +99,7 @@ def __init__( lora_request: Optional[LoRARequest] = None, encoder_prompt: Optional[str] = None, encoder_prompt_token_ids: Optional[List[int]] = None, + prompt_embeds_shape: Optional[Tuple[int, int]] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -164,14 +165,14 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": return cls(seq_group.request_id, prompt, prompt_token_ids, - prompt_embeds_shape, prompt_logprobs, outputs, finished, seq_group.metrics, lora_request=seq_group.lora_request, encoder_prompt=encoder_prompt, - encoder_prompt_token_ids=encoder_prompt_token_ids) + encoder_prompt_token_ids=encoder_prompt_token_ids, + prompt_embeds_shape=prompt_embeds_shape) def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " From 3c1a6fa01dc530205e93b89533d87831d2c5fcf7 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Thu, 8 Aug 2024 14:42:50 -0500 Subject: [PATCH 09/88] feat: all *ForCausalLM models support inputs_embeds --- vllm/model_executor/models/arctic.py | 10 ++++++++-- vllm/model_executor/models/baichuan.py | 10 ++++++++-- vllm/model_executor/models/chatglm.py | 13 ++++++++----- vllm/model_executor/models/commandr.py | 9 +++++++-- vllm/model_executor/models/dbrx.py | 9 +++++++-- vllm/model_executor/models/deepseek.py | 10 +++++++--- vllm/model_executor/models/deepseek_v2.py | 11 ++++++++--- vllm/model_executor/models/falcon.py | 10 +++++++++- vllm/model_executor/models/fuyu.py | 8 +++++--- vllm/model_executor/models/gemma2.py | 9 +++++++-- vllm/model_executor/models/gpt_bigcode.py | 2 +- vllm/model_executor/models/gpt_j.py | 9 +++++++-- vllm/model_executor/models/internlm2.py | 12 +++++++----- vllm/model_executor/models/jamba.py | 21 ++++++++++++++------- 
vllm/model_executor/models/minicpm.py | 10 +++++----- vllm/model_executor/models/mixtral.py | 11 ++++++++--- vllm/model_executor/models/mixtral_quant.py | 10 +++++++--- vllm/model_executor/models/mpt.py | 9 +++++++-- vllm/model_executor/models/nemotron.py | 13 +++++++------ vllm/model_executor/models/olmo.py | 12 ++++++++---- vllm/model_executor/models/orion.py | 11 +++++++++-- vllm/model_executor/models/persimmon.py | 9 +++++---- vllm/model_executor/models/phi3_small.py | 12 +++++++++++- vllm/model_executor/models/phi3v.py | 9 +++++++-- vllm/model_executor/models/qwen2.py | 13 +++++++------ vllm/model_executor/models/qwen2_moe.py | 11 ++++++++--- vllm/model_executor/models/xverse.py | 9 +++++++-- 27 files changed, 199 insertions(+), 83 deletions(-) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 49e57a847e847..87f035f4b1702 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -32,6 +32,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs.arctic import ArcticConfig +from .utils import get_inputs_embeds + logger = init_logger(__name__) @@ -388,8 +390,10 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) for i in range(len(self.layers)): layer = self.layers[i] hidden_states = layer(positions, hidden_states, kv_caches[i], @@ -428,9 +432,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index e1ea8bfcac655..8506f6789f8da 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -46,6 +46,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA +from .utils import get_inputs_embeds def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -278,8 +279,11 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -341,9 +345,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, 
hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 553ddf90475b4..29fcf8789103b 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -29,7 +29,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig from .interfaces import SupportsLoRA - +from .utils import get_inputs_embeds class GLMAttention(nn.Module): @@ -312,12 +312,13 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - inputs_embeds = self.embedding(input_ids) - + hidden_states = get_inputs_embeds(input_ids, self.embedding, inputs_embeds, inputs_embeds_masks) # Run encoder. hidden_states = self.encoder( - hidden_states=inputs_embeds, + hidden_states=hidden_states, position_ids=position_ids, kv_caches=kv_caches, attn_metadata=attn_metadata, @@ -367,9 +368,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 5f6e3a134f408..549197c963b27 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -48,6 +48,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors, SamplerOutput +from .utils import get_inputs_embeds @torch.compile def layer_norm_func(hidden_states, weight, variance_epsilon): @@ -288,8 +289,10 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -354,9 +357,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index d758333b22388..3f32feb89d78b 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -26,6 +26,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs.dbrx import DbrxConfig +from .utils import get_inputs_embeds class DbrxRouter(nn.Module): """A Router implementation for DBRX that returns logits for each expert @@ -338,8 +339,10 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + 
inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.wte(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.wte, inputs_embeds, inputs_embeds_masks) for i in range(len(self.blocks)): block = self.blocks[i] hidden_states = block( @@ -383,9 +386,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 3fd6f2218f3eb..f73c8c33088f5 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -50,7 +50,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput - +from .utils import get_inputs_embeds class DeepseekMLP(nn.Module): def __init__( @@ -353,8 +353,10 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -390,9 +392,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 2e3e9b6f2792e..2943b2f670ba0 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -50,7 +50,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers, get_inputs_embeds class DeepseekV2MLP(nn.Module): @@ -447,9 +447,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -500,9 +502,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: 
Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 93f07327eaa26..8a593a78cdd25 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -47,6 +47,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs import RWConfig +from .utils import get_inputs_embeds + FalconConfig = Union[HF_FalconConfig, RWConfig] @@ -361,8 +363,10 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.word_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.word_embeddings, inputs_embeds, inputs_embeds_masks) for i in range(len(self.h)): layer = self.h[i] hidden_states = layer( @@ -411,12 +415,16 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer( input_ids, positions, kv_caches, attn_metadata, + inputs_embeds, + inputs_embeds_masks, ) return hidden_states diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index c4738263c3056..6b4f9b21da5a8 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -41,7 +41,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from .interfaces import SupportsVision -from .utils import merge_vision_embeddings +from .utils import merge_vision_embeddings, get_inputs_embeds logger = init_logger(__name__) @@ -57,7 +57,7 @@ class FuyuImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor """ - Shape: + Shape: (batch_size, num_patches, patch_size_x * patch_size_y * num_channels) """ @@ -256,6 +256,8 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, **kwargs: object, ): image_input = self._parse_and_validate_image_input(**kwargs) @@ -263,7 +265,7 @@ def forward( if image_input is not None: vision_embeddings, _ = self.vision_embed_tokens( image_input["data"]) - inputs_embeds = self.language_model.model.embed_tokens(input_ids) + inputs_embeds = get_inputs_embeds(input_ids, self.language_model.model.embed_tokens, inputs_embeds, inputs_embeds_masks) inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, vision_embeddings, self.image_token_id) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 7bad2626fec6a..1f506418f4bda 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -41,6 +41,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA +from .utils import get_inputs_embeds logger = init_logger(__name__) @@ -273,8 +274,10 @@ def forward( positions: torch.Tensor, 
kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) hidden_states *= self.normalizer residual = None @@ -338,9 +341,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 254024744b980..fbfb34064bc7f 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -223,7 +223,7 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - inputs_embeds = get_inputs_embeds(inputs_embeds, self.wte, + inputs_embeds = get_inputs_embeds(input_ids, self.wte, inputs_embeds, inputs_embeds_masks) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 4bb9debe7ae81..5ed5cfd056fe6 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -40,6 +40,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +from .utils import get_inputs_embeds class GPTJAttention(nn.Module): @@ -198,8 +199,10 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.wte(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.wte, inputs_embeds, inputs_embeds_masks) for i in range(len(self.h)): layer = self.h[i] hidden_states = layer( @@ -241,9 +244,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 745fbf99a902d..8bd458a3f1aad 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -24,6 +24,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +from .utils import get_inputs_embeds class InternLM2MLP(nn.Module): @@ -230,11 +231,10 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: IntermediateTensors = None, inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if 
inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.tok_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.tok_embeddings, + inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -274,9 +274,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: IntermediateTensors, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index cf407c86acd7d..09166e10dd964 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -37,6 +37,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) +from .utils import get_inputs_embeds KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -523,8 +524,10 @@ def forward( attn_metadata: AttentionMetadata, conv_state: torch.Tensor, ssm_state: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): @@ -628,6 +631,8 @@ def forward(self, kv_caches: List[KVCache], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, **kwargs): if not self.mamba_cache: self._prepare_mamba_cache() @@ -661,7 +666,9 @@ def forward(self, hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, current_seqlen_agnostic_cache[0], - current_seqlen_agnostic_cache[1]) + current_seqlen_agnostic_cache[1], + inputs_embeds=inputs_embeds, + inputs_embeds_masks=inputs_embeds_masks) if "seqlen_agnostic_capture_inputs" not in kwargs: self._copy_mamba_cache_by_indices(self.current_indices, @@ -736,9 +743,9 @@ def _prepare_current_run_mamba_cache( def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ - Copy the relevant Mamba cache into the CUDA graph input buffer - that was provided during the capture runs - (JambaForCausalLM.mamba_gc_cache_buffer). + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (JambaForCausalLM.mamba_gc_cache_buffer). """ assert all( key in kwargs @@ -763,7 +770,7 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): """ Copy the relevant Mamba cache from the CUDA graph input_buffers - back to the JambaForCausalLM.mamba_cache after CUDA + back to the JambaForCausalLM.mamba_cache after CUDA graph replay run is done. """ self._copy_mamba_cache_by_indices( @@ -773,7 +780,7 @@ def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): """ Provide the CUDA graph capture runs with a buffer in adjusted size. 
- The buffer is used to maintain the Mamba Cache during the CUDA graph + The buffer is used to maintain the Mamba Cache during the CUDA graph replay runs. """ return tuple(buffer[:, :batch_size] diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 7f8f38fe8439a..79df63fc1d395 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA - +from .utils import get_inputs_embeds class MiniCPMMoE(nn.Module): """A tensor-parallel MoE implementation that shards each expert @@ -372,11 +372,9 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.get_input_embeddings, inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): @@ -465,6 +463,8 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 8fbd537a2c031..f047e9e9ad1a2 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -48,7 +48,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA -from .utils import is_pp_missing_parameter, make_layers +from .utils import is_pp_missing_parameter, make_layers, get_inputs_embeds class MixtralMoE(nn.Module): @@ -283,9 +283,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -370,9 +372,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 10faa5cc6b6cc..50bcfcd085a56 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput - +from .utils import get_inputs_embeds 
class MixtralMLP(nn.Module): def __init__( @@ -319,8 +319,10 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -357,9 +359,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 7d658b39e6794..ba8d47e29593e 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -25,6 +25,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.transformers_utils.configs.mpt import MPTConfig +from .utils import get_inputs_embeds def _get_alibi_slopes( total_num_heads: int, @@ -235,8 +236,10 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.wte(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.wte, inputs_embeds, inputs_embeds_masks) for i in range(len(self.blocks)): block = self.blocks[i] hidden_states = block( @@ -274,9 +277,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index bb85f20ab9802..008686957cde2 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -47,7 +47,7 @@ from vllm.transformers_utils.configs import NemotronConfig from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers, get_inputs_embeds # The architecture is pretty similar to Llama, with these changes: # - There is no gate_proj, just up_proj @@ -340,12 +340,10 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.get_input_embeddings, inputs_embeds, inputs_embeds_masks) residual = None else: assert 
intermediate_tensors is not None
@@ -448,9 +446,12 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds_masks: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         model_output = self.model(input_ids, positions, kv_caches,
-                                  attn_metadata, intermediate_tensors)
+                                  attn_metadata, intermediate_tensors,
+                                  inputs_embeds, inputs_embeds_masks)
         return model_output
 
     def compute_logits(self, hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 1a0a3774dc8fb..fd486e70848eb 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -45,6 +45,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
+from .utils import get_inputs_embeds
 
 class OlmoAttention(nn.Module):
     """
@@ -243,16 +244,15 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds_masks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         :param input_ids: A tensor of shape `(batch_size, seq_len)`.
         """
         # Get embeddings of input.
         # shape: (batch_size, seq_len, d_model)
-        inputs_embeds = self.embed_tokens(input_ids)
-
-        # embed positions
-        hidden_states = inputs_embeds
+        hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks)
 
         # Apply blocks one-by-one.
         for layer_idx, decoder_layer in enumerate(self.layers):
@@ -302,12 +302,16 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds_masks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
+            inputs_embeds=inputs_embeds,
+            inputs_embeds_masks=inputs_embeds_masks,
         )
         return hidden_states
 
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index 8159cc13fba0b..b43adbf65be32 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -28,6 +28,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
+from .utils import get_inputs_embeds
+
 
 class OrionMLP(nn.Module):
 
@@ -231,8 +233,11 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds_masks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds,
+                                          inputs_embeds_masks)
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -272,9 +277,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds_masks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, inputs_embeds, inputs_embeds_masks)
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
diff --git
a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index bc38d4421b79e..eaff26b28a205 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -44,6 +44,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +from .utils import get_inputs_embeds + class PersimmonMLP(nn.Module): @@ -233,11 +235,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) for i in range(len(self.layers)): hidden_states = self.layers[i]( positions, @@ -275,6 +275,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ): hidden_states = self.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index cc06929fefab4..e04558c6e1aa3 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -23,6 +23,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +from .utils import get_inputs_embeds + def load_column_parallel_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor): @@ -301,6 +303,8 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ): super().__init__() self.config = config @@ -328,8 +332,10 @@ def forward( positions: Optional[torch.LongTensor], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ): - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) if (self.mup_embedding_multiplier is not None and self.mup_embedding_multiplier > 0.0): hidden_states = hidden_states * self.mup_embedding_multiplier @@ -414,12 +420,16 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: output_hidden_states = self.model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + inputs_embeds_masks=inputs_embeds_masks, ) output_hidden_states = output_hidden_states return output_hidden_states diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 823c34b101870..06ff89a7d368d 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -43,7 +43,7 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, input_processor_for_clip) from .interfaces import SupportsVision -from .utils import merge_vision_embeddings +from 
.utils import merge_vision_embeddings, get_inputs_embeds logger = init_logger(__name__) @@ -522,6 +522,8 @@ def forward(self, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, **kwargs: object): image_input = self._parse_and_validate_image_input(**kwargs) @@ -529,10 +531,12 @@ def forward(self, vision_embeddings = self.vision_embed_tokens( image_input["data"], image_input["image_sizes"]) inputs_embeds = self.model.get_input_embeddings(input_ids) + inputs_embeds = get_inputs_embeds(input_ids, self.model.get_input_embeddings, inputs_embeds, inputs_embeds_masks) inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, vision_embeddings, self.image_token_id) input_ids = None + inputs_embeds_masks.fill_(1) else: inputs_embeds = None @@ -541,7 +545,8 @@ def forward(self, kv_caches, attn_metadata, intermediate_tensors, - inputs_embeds=inputs_embeds) + inputs_embeds=inputs_embeds, + inputs_embeds_masks=inputs_embeds_masks) return hidden_states diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index a66a1eee7c160..e9f3705391cd5 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA -from .utils import is_pp_missing_parameter, make_layers +from .utils import is_pp_missing_parameter, make_layers, get_inputs_embeds class Qwen2MLP(nn.Module): @@ -261,12 +261,10 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -357,9 +355,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index b895788206d10..cff144c4c4c2c 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import print_warning_once -from .utils import is_pp_missing_parameter, make_layers +from .utils import is_pp_missing_parameter, make_layers, get_inputs_embeds class Qwen2MoeMLP(nn.Module): @@ -346,9 +346,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: - 
hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -395,9 +397,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 84f0ffc376d65..a2f561fba339b 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -46,7 +46,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA - +from .utils import get_inputs_embeds class XverseMLP(nn.Module): @@ -252,8 +252,11 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -323,9 +326,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, From c454647906cb09dfa0f43cbd9fc46d36ec48c4b9 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Thu, 8 Aug 2024 14:43:53 -0500 Subject: [PATCH 10/88] fix: format --- vllm/model_executor/models/arctic.py | 6 ++++-- vllm/model_executor/models/baichuan.py | 3 ++- vllm/model_executor/models/chatglm.py | 7 +++++-- vllm/model_executor/models/commandr.py | 7 +++++-- vllm/model_executor/models/dbrx.py | 7 +++++-- vllm/model_executor/models/deepseek.py | 8 ++++++-- vllm/model_executor/models/deepseek_v2.py | 4 +++- vllm/model_executor/models/falcon.py | 3 ++- vllm/model_executor/models/fuyu.py | 4 +++- vllm/model_executor/models/gemma2.py | 6 ++++-- vllm/model_executor/models/gpt_bigcode.py | 4 ++-- vllm/model_executor/models/gpt_j.py | 7 +++++-- vllm/model_executor/models/internlm2.py | 4 +++- vllm/model_executor/models/jamba.py | 7 +++++-- vllm/model_executor/models/minicpm.py | 4 +++- vllm/model_executor/models/mixtral.py | 4 +++- vllm/model_executor/models/mixtral_quant.py | 8 ++++++-- vllm/model_executor/models/mpt.py | 7 +++++-- vllm/model_executor/models/nemotron.py | 5 ++++- vllm/model_executor/models/olmo.py | 4 +++- vllm/model_executor/models/orion.py | 7 ++++--- vllm/model_executor/models/persimmon.py | 3 ++- vllm/model_executor/models/phi3_small.py | 3 ++- vllm/model_executor/models/phi3v.py | 5 ++++- vllm/model_executor/models/qwen2.py | 4 +++- vllm/model_executor/models/qwen2_moe.py | 4 +++- vllm/model_executor/models/xverse.py | 7 +++++-- 
27 files changed, 101 insertions(+), 41 deletions(-) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 87f035f4b1702..024dedc0b0e9d 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -393,7 +393,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) for i in range(len(self.layers)): layer = self.layers[i] hidden_states = layer(positions, hidden_states, kv_caches[i], @@ -436,7 +437,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, inputs_embeds_masks) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 8506f6789f8da..a303801527905 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -349,7 +349,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, inputs_embeds_masks) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 29fcf8789103b..7169ddb9a272c 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -31,6 +31,7 @@ from .interfaces import SupportsLoRA from .utils import get_inputs_embeds + class GLMAttention(nn.Module): def __init__( @@ -315,7 +316,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.embedding, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embedding, + inputs_embeds, inputs_embeds_masks) # Run encoder. 
hidden_states = self.encoder( hidden_states=hidden_states, @@ -372,7 +374,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, inputs_embeds_masks) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 549197c963b27..27c668b58d4dc 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -50,6 +50,7 @@ from .utils import get_inputs_embeds + @torch.compile def layer_norm_func(hidden_states, weight, variance_epsilon): input_dtype = hidden_states.dtype @@ -292,7 +293,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -361,7 +363,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, inputs_embeds_masks) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 3f32feb89d78b..dfd18be9bcb56 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -28,6 +28,7 @@ from .utils import get_inputs_embeds + class DbrxRouter(nn.Module): """A Router implementation for DBRX that returns logits for each expert per token. 
@@ -342,7 +343,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.wte, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.wte, inputs_embeds, + inputs_embeds_masks) for i in range(len(self.blocks)): block = self.blocks[i] hidden_states = block( @@ -390,7 +392,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, inputs_embeds_masks) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index f73c8c33088f5..f3ac2e87c9027 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -51,6 +51,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .utils import get_inputs_embeds + + class DeepseekMLP(nn.Module): def __init__( @@ -356,7 +358,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -396,7 +399,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, inputs_embeds_masks) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 2943b2f670ba0..bbb604db61758 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -451,7 +451,9 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 8a593a78cdd25..bf6577cfb98f7 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -366,7 +366,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.word_embeddings, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.word_embeddings, + inputs_embeds, inputs_embeds_masks) for i in range(len(self.h)): layer = self.h[i] hidden_states = layer( diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 6b4f9b21da5a8..1e7863c8438ef 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -265,7 +265,9 @@ def forward( if image_input is not None: vision_embeddings, _ = self.vision_embed_tokens( 
image_input["data"]) - inputs_embeds = get_inputs_embeds(input_ids, self.language_model.model.embed_tokens, inputs_embeds, inputs_embeds_masks) + inputs_embeds = get_inputs_embeds( + input_ids, self.language_model.model.embed_tokens, + inputs_embeds, inputs_embeds_masks) inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, vision_embeddings, self.image_token_id) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 1f506418f4bda..41a5a20c7445a 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -277,7 +277,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) hidden_states *= self.normalizer residual = None @@ -345,7 +346,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, inputs_embeds_masks) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index fbfb34064bc7f..24969a46048ce 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -223,8 +223,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - inputs_embeds = get_inputs_embeds(input_ids, self.wte, - inputs_embeds, inputs_embeds_masks) + inputs_embeds = get_inputs_embeds(input_ids, self.wte, inputs_embeds, + inputs_embeds_masks) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 5ed5cfd056fe6..b9eb8b5480ea2 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -42,6 +42,7 @@ from .utils import get_inputs_embeds + class GPTJAttention(nn.Module): def __init__( @@ -202,7 +203,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.wte, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.wte, inputs_embeds, + inputs_embeds_masks) for i in range(len(self.h)): layer = self.h[i] hidden_states = layer( @@ -248,7 +250,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, inputs_embeds_masks) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 8bd458a3f1aad..2efc50460f55a 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -26,6 +26,7 @@ from .utils import get_inputs_embeds + class InternLM2MLP(nn.Module): def __init__( @@ -278,7 +279,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = 
self.model(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, inputs_embeds_masks) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 09166e10dd964..6d4c2f6c10655 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -527,7 +527,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): @@ -663,7 +664,9 @@ def forward(self, ) self.current_indices = indices - hidden_states = self.model(input_ids, positions, kv_caches, + hidden_states = self.model(input_ids, + positions, + kv_caches, attn_metadata, current_seqlen_agnostic_cache[0], current_seqlen_agnostic_cache[1], diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 79df63fc1d395..7636818e07220 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -55,6 +55,7 @@ from .interfaces import SupportsLoRA from .utils import get_inputs_embeds + class MiniCPMMoE(nn.Module): """A tensor-parallel MoE implementation that shards each expert across all ranks. @@ -374,7 +375,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.get_input_embeddings, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.get_input_embeddings, + inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index f047e9e9ad1a2..7d64e91a11728 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -287,7 +287,9 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 50bcfcd085a56..e37a94b8d86e0 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -50,6 +50,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .utils import get_inputs_embeds + + class MixtralMLP(nn.Module): def __init__( @@ -322,7 +324,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -363,7 +366,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: 
hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, inputs_embeds_masks) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index ba8d47e29593e..8d6fa610fe4b4 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -27,6 +27,7 @@ from .utils import get_inputs_embeds + def _get_alibi_slopes( total_num_heads: int, alibi_bias_max: int, @@ -239,7 +240,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.wte, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.wte, inputs_embeds, + inputs_embeds_masks) for i in range(len(self.blocks)): block = self.blocks[i] hidden_states = block( @@ -281,7 +283,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, inputs_embeds_masks) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 008686957cde2..300c9d6796de1 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -343,7 +343,10 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, self.get_input_embeddings, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, + self.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index fd486e70848eb..e83062a3f656e 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -47,6 +47,7 @@ from .utils import get_inputs_embeds + class OlmoAttention(nn.Module): """ This is the attention block where the output is computed as @@ -252,7 +253,8 @@ def forward( """ # Get embeddings of input. # shape: (batch_size, seq_len, d_model) - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) # Apply blocks one-by-one. 
for layer_idx, decoder_layer in enumerate(self.layers): diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index b43adbf65be32..e5a1b88784754 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -236,8 +236,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, - inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -281,7 +281,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, inputs_embeds_masks) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index eaff26b28a205..efcd808e135c4 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -237,7 +237,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) for i in range(len(self.layers)): hidden_states = self.layers[i]( positions, diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index e04558c6e1aa3..ef8191bbb5324 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -335,7 +335,8 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ): - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) if (self.mup_embedding_multiplier is not None and self.mup_embedding_multiplier > 0.0): hidden_states = hidden_states * self.mup_embedding_multiplier diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 06ff89a7d368d..16e3501bab658 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -531,7 +531,10 @@ def forward(self, vision_embeddings = self.vision_embed_tokens( image_input["data"], image_input["image_sizes"]) inputs_embeds = self.model.get_input_embeddings(input_ids) - inputs_embeds = get_inputs_embeds(input_ids, self.model.get_input_embeddings, inputs_embeds, inputs_embeds_masks) + inputs_embeds = get_inputs_embeds(input_ids, + self.model.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, vision_embeddings, self.image_token_id) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index e9f3705391cd5..75a30bcc39c7c 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -264,7 +264,9 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: - hidden_states = 
get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index cff144c4c4c2c..5d297422154f5 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -350,7 +350,9 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index a2f561fba339b..9819a289fdbae 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -48,6 +48,7 @@ from .interfaces import SupportsLoRA from .utils import get_inputs_embeds + class XverseMLP(nn.Module): def __init__( @@ -256,7 +257,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -330,7 +332,8 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, inputs_embeds_masks) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, From 535ad974f940ccd469318ebe958199eec29f12cd Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Thu, 8 Aug 2024 14:45:25 -0500 Subject: [PATCH 11/88] fix: format --- vllm/model_executor/models/deepseek_v2.py | 3 ++- vllm/model_executor/models/nemotron.py | 3 ++- vllm/model_executor/models/olmo.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index bbb604db61758..863ab97d23527 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -50,7 +50,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers, get_inputs_embeds +from .utils import ( + PPMissingLayer, is_pp_missing_parameter, make_layers, get_inputs_embeds) class DeepseekV2MLP(nn.Module): diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 300c9d6796de1..a3fa75e63a45d 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -47,7 +47,8 @@ from vllm.transformers_utils.configs import NemotronConfig from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers, get_inputs_embeds +from .utils import ( + PPMissingLayer, is_pp_missing_parameter, make_layers, get_inputs_embeds) # The architecture is pretty similar to Llama, with 
these changes: # - There is no gate_proj, just up_proj diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index e83062a3f656e..5f92da79c4e69 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -312,8 +312,8 @@ def forward( positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, - inputs_embeds, - inputs_embeds_masks, + inputs_embeds=inputs_embeds, + inputs_embeds_masks=inputs_embeds_masks, ) return hidden_states From c05a8ff49c88f33e0ffdd153da1d1aff30c0abec Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Thu, 8 Aug 2024 14:46:53 -0500 Subject: [PATCH 12/88] fix: format --- vllm/model_executor/models/deepseek_v2.py | 4 ++-- vllm/model_executor/models/fuyu.py | 2 +- vllm/model_executor/models/jamba.py | 1 + vllm/model_executor/models/mixtral.py | 2 +- vllm/model_executor/models/nemotron.py | 4 ++-- vllm/model_executor/models/phi3v.py | 2 +- vllm/model_executor/models/qwen2.py | 2 +- vllm/model_executor/models/qwen2_moe.py | 2 +- 8 files changed, 10 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 863ab97d23527..3deb0b85e35d8 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -50,8 +50,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from .utils import ( - PPMissingLayer, is_pp_missing_parameter, make_layers, get_inputs_embeds) +from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter, + make_layers) class DeepseekV2MLP(nn.Module): diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 1e7863c8438ef..29ca21d4f9c06 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -41,7 +41,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from .interfaces import SupportsVision -from .utils import merge_vision_embeddings, get_inputs_embeds +from .utils import get_inputs_embeds, merge_vision_embeddings logger = init_logger(__name__) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 6d4c2f6c10655..c14f6f5deb33e 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -37,6 +37,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) + from .utils import get_inputs_embeds KVCache = Tuple[torch.Tensor, torch.Tensor] diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 7d64e91a11728..99c9f641622aa 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -48,7 +48,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA -from .utils import is_pp_missing_parameter, make_layers, get_inputs_embeds +from .utils import get_inputs_embeds, is_pp_missing_parameter, make_layers class MixtralMoE(nn.Module): diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index a3fa75e63a45d..d6ad7774886fd 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -47,8 +47,8 @@ from vllm.transformers_utils.configs import NemotronConfig from .interfaces import SupportsLoRA -from .utils import ( - PPMissingLayer, 
is_pp_missing_parameter, make_layers, get_inputs_embeds) +from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter, + make_layers) # The architecture is pretty similar to Llama, with these changes: # - There is no gate_proj, just up_proj diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 16e3501bab658..44ab97c4d7d7d 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -43,7 +43,7 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, input_processor_for_clip) from .interfaces import SupportsVision -from .utils import merge_vision_embeddings, get_inputs_embeds +from .utils import get_inputs_embeds, merge_vision_embeddings logger = init_logger(__name__) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 75a30bcc39c7c..6e3c9d7d37fa1 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA -from .utils import is_pp_missing_parameter, make_layers, get_inputs_embeds +from .utils import get_inputs_embeds, is_pp_missing_parameter, make_layers class Qwen2MLP(nn.Module): diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 5d297422154f5..a238be3495ca2 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import print_warning_once -from .utils import is_pp_missing_parameter, make_layers, get_inputs_embeds +from .utils import get_inputs_embeds, is_pp_missing_parameter, make_layers class Qwen2MoeMLP(nn.Module): From d83915baa602695dd7bf8f96068861859b3668f7 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Wed, 4 Sep 2024 16:23:53 -0500 Subject: [PATCH 13/88] fix: engines --- vllm/engine/async_llm_engine.py | 1 + vllm/engine/llm_engine.py | 5 +++-- vllm/outputs.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index cf8b45911e945..f9a0637bc7b6a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -434,6 +434,7 @@ async def _extract_prompt_components_async( lora_request=lora_request, ) multi_modal_data = None + prompt_embeds = None elif isinstance(inputs, dict): prompt = inputs.get("prompt") prompt_embeds = inputs.get("prompt_embeds") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bb1734989bbe5..493b1a9fe4d97 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -987,7 +987,7 @@ def _build_decoder_only_llm_inputs( return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=prompt, - prompt_embeds=prompt_embeds + prompt_embeds=prompt_embeds, multi_modal_data=multi_modal_data) def _process_decoder_only_prompt( @@ -2013,8 +2013,9 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs, prompt_ids = inputs.get("encoder_prompt_token_ids") else: prompt_ids = inputs.get("prompt_token_ids") + prompt_embeds = inputs.get("prompt_embeds") - if prompt_ids is None or len(prompt_ids) == 0: + if (prompt_ids is None or len(prompt_ids) == 0) and prompt_embeds is None: raise ValueError("Prompt cannot be empty") if self.model_config.is_multimodal_model: diff --git a/vllm/outputs.py b/vllm/outputs.py index 64b0e4d56586f..b3f8e29353950 100644 --- a/vllm/outputs.py +++ 
b/vllm/outputs.py @@ -1,6 +1,6 @@ import time from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Tuple from typing import Sequence as GenericSequence from typing import Union From dfd93011f153d49e9f9da88e8aa35e97baf4e3c8 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Wed, 4 Sep 2024 17:27:51 -0500 Subject: [PATCH 14/88] fix: format --- tests/conftest.py | 5 ++-- tests/worker/test_model_runner.py | 34 +++++++++++++------------ vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 12 +++++---- vllm/inputs/__init__.py | 9 ++++--- vllm/model_executor/models/fuyu.py | 7 +++-- vllm/model_executor/models/jamba.py | 7 +++-- vllm/model_executor/models/persimmon.py | 14 +++++----- vllm/model_executor/models/phi3v.py | 5 ++-- vllm/sequence.py | 3 ++- 10 files changed, 52 insertions(+), 46 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 74bef1d5d9e95..f0d91030b8d3a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,8 +27,9 @@ destroy_model_parallel, init_distributed_environment, initialize_model_parallel) -from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, EmbedsPrompt, - to_enc_dec_tuple_list, zip_enc_dec_prompts) +from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, + EmbedsPrompt, to_enc_dec_tuple_list, + zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index f8f172657a87a..d75aa617638ee 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -34,9 +34,8 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: return model_runner -@pytest.mark.parametrize("batch_size, prompt_embeds_ratio", - list(itertools.product(range(1, 257), - (0.0, 0.5, 1.0)))) +@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) +@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0)) def test_prepare_prompt(batch_size, prompt_embeds_ratio): model_runner = _create_model_runner( "facebook/opt-125m", @@ -54,11 +53,13 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio): seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) if random.random() < prompt_embeds_ratio: - seq_data = SequenceData([], prompt_embeds=torch.rand(seq_len, 10)) + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len)), + torch.rand(seq_len, 10)) input_embeds_len += seq_len - else - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, - range(seq_len))) + else: + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -163,7 +164,7 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio): torch.testing.assert_close(actual, expected) -@pytest.mark.parametrize("batch_size", list(range(1, 257))) +@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) @pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0)) def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio): model_runner = _create_model_runner( @@ -185,8 +186,8 @@ def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio): context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) if random.random() < prompt_embeds_ratio: - seq_data = SequenceData([], - prompt_embeds=torch.rand(context_len, 10)) + seq_data = 
SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)), + torch.rand(context_len, 10)) input_embeds_len += context_len else: seq_data = SequenceData( @@ -337,7 +338,7 @@ def distributed_init(): ensure_model_parallel_initialized(1, 1) -@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("batch_size", list(range(2, 128, 3))) @pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize('prompt_embeds_ratio', [0.0, 0.5, 1.0]) def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, @@ -366,11 +367,12 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) if random.random() < prompt_embeds_ratio: - seq_data = SequenceData([], prompt_embeds=torch.rand(seq_len, 10)) + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)), + torch.rand(seq_len, 10)) input_embeds_len += seq_len else: - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, - range(seq_len))) + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -387,8 +389,8 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 if random.random() < prompt_embeds_ratio: - seq_data = SequenceData([], - prompt_embeds=torch.rand(context_len, 10)) + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)), + torch.rand(context_len, 10)) else: prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)) seq_data = SequenceData(prompt_toks) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f9a0637bc7b6a..c8f744f82fd75 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -436,7 +436,6 @@ async def _extract_prompt_components_async( multi_modal_data = None prompt_embeds = None elif isinstance(inputs, dict): - prompt = inputs.get("prompt") prompt_embeds = inputs.get("prompt_embeds") driver_worker = self.model_executor.driver_worker if prompt_embeds is not None: @@ -450,6 +449,7 @@ async def _extract_prompt_components_async( raise ValueError( f"Model {self.model_config.model} does not support input " "embeddings, but prompt_embeds was provided.") + prompt = None prompt_token_ids = [] elif "prompt_token_ids" in inputs: prompt = None diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 493b1a9fe4d97..9174d12ab3092 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -76,7 +76,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) -PromptComponents = Tuple[Optional[str], List[int], +PromptComponents = Tuple[Optional[str], List[int], Optional[torch.Tensor], Optional[MultiModalDataDict]] DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], Optional[MultiModalDataDict]] @@ -808,7 +808,6 @@ def _extract_prompt_components( multi_modal_data = None prompt_embeds = None elif isinstance(inputs, dict): - prompt = inputs.get("prompt") prompt_embeds = inputs.get("prompt_embeds") driver_worker = self.model_executor.driver_worker if prompt_embeds is not None: @@ -822,6 +821,7 @@ def _extract_prompt_components( raise ValueError( f"Model {self.model_config.model} does not support input " 
"embeddings, but prompt_embeds was provided.") + prompt = None prompt_token_ids = [] elif "prompt_token_ids" in inputs: prompt = None @@ -894,7 +894,7 @@ def _build_enc_dec_llm_inputs( encoder_comps: PromptComponents, decoder_comps: DecoderPromptComponents, ) -> EncoderDecoderLLMInputs: - encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps + encoder_prompt, encoder_prompt_ids, _, encoder_mm_data = encoder_comps decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps if encoder_mm_data is not None or decoder_mm_data is not None: @@ -961,10 +961,11 @@ def _process_encoder_decoder_prompt( if (decoder_input := inputs["decoder_prompt"]) is None: decoder_comps = None, None, None else: - decoder_comps = self._extract_prompt_components( + prompt, prompt_token_ids, _, multi_modal_data = self._extract_prompt_components( decoder_input, request_id=request_id, ) + decoder_comps = prompt, prompt_token_ids, multi_modal_data else: encoder_comps = self._extract_prompt_components( inputs, @@ -2015,7 +2016,8 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs, prompt_ids = inputs.get("prompt_token_ids") prompt_embeds = inputs.get("prompt_embeds") - if (prompt_ids is None or len(prompt_ids) == 0) and prompt_embeds is None: + if (prompt_ids is None + or len(prompt_ids) == 0) and prompt_embeds is None: raise ValueError("Prompt cannot be empty") if self.model_config.is_multimodal_model: diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 6d93b4bf06504..d83878019138d 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,8 @@ -from .data import (EmbedsPrompt, EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, - TokensPrompt, build_explicit_enc_dec_prompt, - to_enc_dec_tuple_list, zip_enc_dec_prompts) +from .data import (EmbedsPrompt, EncoderDecoderLLMInputs, + ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs, + SingletonPromptInputs, TextPrompt, TokensPrompt, + build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, + zip_enc_dec_prompts) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 027885d9e19f3..a2fd67e79b018 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -284,8 +284,8 @@ def forward( ): image_input = self._parse_and_validate_image_input(**kwargs) inputs_embeds = get_inputs_embeds( - input_ids, self.language_model.model.embed_tokens, - inputs_embeds, inputs_embeds_masks) + input_ids, self.language_model.model.embed_tokens, inputs_embeds, + inputs_embeds_masks) if image_input is not None: vision_embeddings = self._process_image_input(image_input) inputs_embeds = merge_multimodal_embeddings( @@ -298,8 +298,7 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, - inputs_embeds_masks=inputs_embeds_masks - ) + inputs_embeds_masks=inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index cc1e23fc30bad..49df9481f35aa 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -637,8 +637,11 @@ def forward(self, # CUDA graph capturing runs mamba_cache = kwargs["seqlen_agnostic_capture_inputs"] - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache[0], + hidden_states = self.model(input_ids, + 
positions, + kv_caches, + attn_metadata, + mamba_cache[0], mamba_cache[1], inputs_embeds=inputs_embeds, inputs_embeds_masks=inputs_embeds_masks) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 3a31f18f63620..e75621ec50b1f 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -278,14 +278,12 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ): - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - inputs_embeds=inputs_embeds, - inputs_embeds_masks=inputs_embeds_masks - ) + hidden_states = self.model(input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + inputs_embeds_masks=inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 1744e9a45a1b8..a86cc723d51ff 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -620,9 +620,8 @@ def forward(self, **kwargs: object): image_input = self._parse_and_validate_image_input(**kwargs) inputs_embeds = get_inputs_embeds(input_ids, - self.model.get_input_embeddings, - inputs_embeds, - inputs_embeds_masks) + self.model.get_input_embeddings, + inputs_embeds, inputs_embeds_masks) if image_input is not None: vision_embeddings = self._process_image_input(image_input) inputs_embeds = merge_multimodal_embeddings( diff --git a/vllm/sequence.py b/vllm/sequence.py index 12f88b026bb8d..58c09851bc971 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -402,7 +402,8 @@ def __init__( "encoder input prompt fields?") self.data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids), self.prompt_embeds) + array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids), + self.prompt_embeds) self.output_logprobs: SampleLogprobs = [] self.output_text = "" From fd455ebafac514b8063fbbd43dee9f98e7fdf739 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Wed, 4 Sep 2024 17:36:15 -0500 Subject: [PATCH 15/88] fix: format --- tests/conftest.py | 4 ++-- tests/worker/test_model_runner.py | 1 - vllm/engine/async_llm_engine.py | 8 ++++---- vllm/engine/llm_engine.py | 14 +++++++------- vllm/outputs.py | 4 ++-- 5 files changed, 15 insertions(+), 16 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f0d91030b8d3a..c6b901d31333c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,8 +27,8 @@ destroy_model_parallel, init_distributed_environment, initialize_model_parallel) -from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, - EmbedsPrompt, to_enc_dec_tuple_list, +from vllm.inputs import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, + TextPrompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index d75aa617638ee..391cb2725b6e7 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,4 +1,3 @@ -import itertools import random from array import array from typing import List diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index c8f744f82fd75..92613eaee46fe 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -447,8 +447,8 @@ async def 
_extract_prompt_components_async( WorkerWrapperBase) else driver_worker.model_runner if not model_runner.model_supports_input_embeds: raise ValueError( - f"Model {self.model_config.model} does not support input " - "embeddings, but prompt_embeds was provided.") + f"Model {self.model_config.model} does not support " + "input embeddings, but prompt_embeds was provided.") prompt = None prompt_token_ids = [] elif "prompt_token_ids" in inputs: @@ -486,7 +486,7 @@ async def _process_encoder_decoder_prompt_async( if (decoder_input := inputs["decoder_prompt"]) is None: encoder_comps = await encoder_task - decoder_comps = None, None, None + decoder_comps = None, None, None, None else: decoder_task = self._extract_prompt_components_async( decoder_input, @@ -501,7 +501,7 @@ async def _process_encoder_decoder_prompt_async( request_id=request_id, ) - decoder_comps = None, None, None + decoder_comps = None, None, None, None return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9174d12ab3092..81316473497b7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -79,6 +79,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: PromptComponents = Tuple[Optional[str], List[int], Optional[torch.Tensor], Optional[MultiModalDataDict]] DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], + Optional[torch.Tensor], Optional[MultiModalDataDict]] @@ -819,8 +820,8 @@ def _extract_prompt_components( WorkerWrapperBase) else driver_worker.model_runner if not model_runner.model_supports_input_embeds: raise ValueError( - f"Model {self.model_config.model} does not support input " - "embeddings, but prompt_embeds was provided.") + f"Model {self.model_config.model} does not support " + "input embeddings, but prompt_embeds was provided.") prompt = None prompt_token_ids = [] elif "prompt_token_ids" in inputs: @@ -895,7 +896,7 @@ def _build_enc_dec_llm_inputs( decoder_comps: DecoderPromptComponents, ) -> EncoderDecoderLLMInputs: encoder_prompt, encoder_prompt_ids, _, encoder_mm_data = encoder_comps - decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps + decoder_prompt, decoder_prompt_ids, _, decoder_mm_data = decoder_comps if encoder_mm_data is not None or decoder_mm_data is not None: raise ValueError("Multi-modal encoder-decoder models are " @@ -959,20 +960,19 @@ def _process_encoder_decoder_prompt( ) if (decoder_input := inputs["decoder_prompt"]) is None: - decoder_comps = None, None, None + decoder_comps = None, None, None, None else: - prompt, prompt_token_ids, _, multi_modal_data = self._extract_prompt_components( + decoder_comps = self._extract_prompt_components( decoder_input, request_id=request_id, ) - decoder_comps = prompt, prompt_token_ids, multi_modal_data else: encoder_comps = self._extract_prompt_components( inputs, request_id=request_id, ) - decoder_comps = None, None, None + decoder_comps = None, None, None, None return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) diff --git a/vllm/outputs.py b/vllm/outputs.py index b3f8e29353950..11b5813f89e45 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,8 +1,8 @@ import time from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import List, Optional from typing import Sequence as GenericSequence -from typing import Union +from typing import Tuple, Union from vllm.lora.request import LoRARequest from vllm.sequence import (PromptLogprobs, RequestMetrics, 
SampleLogprobs, From 03fcf3bcd57380d7e8ce833bbc6b8d1206141611 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Thu, 5 Sep 2024 14:40:02 -0500 Subject: [PATCH 16/88] fix: sequence property and embeddings for phi3v --- vllm/model_executor/models/phi3v.py | 8 ++++---- vllm/sequence.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index a86cc723d51ff..65d013144f709 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -619,16 +619,16 @@ def forward(self, inputs_embeds_masks: Optional[torch.Tensor] = None, **kwargs: object): image_input = self._parse_and_validate_image_input(**kwargs) - inputs_embeds = get_inputs_embeds(input_ids, - self.model.get_input_embeddings, - inputs_embeds, inputs_embeds_masks) if image_input is not None: + inputs_embeds = get_inputs_embeds(input_ids, + self.model.get_input_embeddings, + inputs_embeds, inputs_embeds_masks) vision_embeddings = self._process_image_input(image_input) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, self.image_token_id) + inputs_embeds_masks = torch.ones_like(input_ids, dtype=torch.bool) input_ids = None - inputs_embeds_masks.fill_(1) hidden_states = self.model(input_ids, positions, diff --git a/vllm/sequence.py b/vllm/sequence.py index 58c09851bc971..ca6c80ffbfc7a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -201,6 +201,7 @@ def prompt_embeds(self) -> Optional[torch.Tensor]: def prompt_embeds(self, new_prompt_embeds: Optional[torch.Tensor]) -> None: self._prompt_embeds = new_prompt_embeds + @property def prompt_token_ids_array(self) -> array: """Return the prompt token ids in array type. From a72930fd596b79bffbd7225caa67d9b5236fc894 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Fri, 6 Sep 2024 08:58:34 -0500 Subject: [PATCH 17/88] fix: ultravox --- vllm/model_executor/models/ultravox.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 416fabda831a2..1b7447806d19f 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -426,9 +426,11 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, audio_embeddings, _AUDIO_PLACEHOLDER_TOKEN) + inputs_embeds_masks = torch.ones_like(input_ids, dtype=torch.bool) input_ids = None else: inputs_embeds = None + inputs_embeds_masks = None hidden_states = self.language_model.model( input_ids=input_ids, @@ -436,7 +438,8 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches=kv_caches, attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + inputs_embeds=inputs_embeds, + inputs_embeds_masks=inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, From 5e3eec9bcb7679805dca803384ea686900577658 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Fri, 6 Sep 2024 09:00:05 -0500 Subject: [PATCH 18/88] refactor: rename parameter --- vllm/model_executor/models/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 6c70e95e7e8a6..dfe0e8c7a93d9 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -284,7 +284,7 @@ def is_pp_missing_parameter(name: str, model: 
torch.nn.Module) -> bool: def get_inputs_embeds( input_ids: torch.Tensor, - embeddings: torch.nn.Module, + embeddings_module: torch.nn.Module, inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -293,10 +293,10 @@ def get_inputs_embeds( if all(inputs_embeds_masks): hidden_states = inputs_embeds else: - hidden_states = embeddings(input_ids) + hidden_states = embeddings_module(input_ids) hidden_states[inputs_embeds_masks] = inputs_embeds else: - hidden_states = embeddings(input_ids) + hidden_states = embeddings_module(input_ids) return hidden_states From 2b390263459622d8d7e93586b2c974d8de9d7614 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Wed, 11 Sep 2024 13:19:06 -0500 Subject: [PATCH 19/88] refactor: supports_input_embeds --- vllm/model_executor/models/interfaces.py | 9 +++++++++ vllm/model_executor/models/utils.py | 2 +- vllm/worker/model_runner.py | 9 ++++----- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 069948f812253..e5d1797fd57a2 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,6 +1,8 @@ +import inspect from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, Union, overload, runtime_checkable) +import torch.nn as nn from typing_extensions import TypeIs from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig @@ -189,3 +191,10 @@ def has_inner_state( return isinstance(model, _HasInnerStateType) return isinstance(model, HasInnerState) + + +def supports_input_embeds(model: nn.Module) -> bool: + """Check if the model supports input_embeds and input_embeds_masks.""" + model_forward_params = inspect.signature(model.forward).parameters + return ("inputs_embeds" in model_forward_params + and "inputs_embeds_masks" in model_forward_params) \ No newline at end of file diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index dfe0e8c7a93d9..245d2d273c353 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -290,7 +290,7 @@ def get_inputs_embeds( ) -> torch.Tensor: """Get the input embeddings from either `input_ids` and `inputs_embeds`.""" if inputs_embeds is not None: - if all(inputs_embeds_masks): + if inputs_embeds_masks.all().item(): hidden_states = inputs_embeds else: hidden_states = embeddings_module(input_ids) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ca42c0506bb3c..1394cbd9f2729 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -34,7 +34,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.model_executor.models.interfaces import (supports_lora, +from vllm.model_executor.models.interfaces import (supports_input_embeds, + supports_lora, supports_multimodal) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, @@ -1039,10 +1040,8 @@ def load_model(self) -> None: parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, cache_config=self.cache_config) - model_forward_params = inspect.signature(self.model.forward).parameters - if ("inputs_embeds" in model_forward_params - and "inputs_embeds_masks" in model_forward_params): - 
self.model_supports_input_embeds = True + self.model_supports_input_embeds = supports_input_embeds( + self.model) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) From 49fe3f7404bb98bcef1b4bdccfc870c3b371ca74 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Thu, 12 Sep 2024 16:08:16 -0500 Subject: [PATCH 20/88] feat: inputs_embeds for new models --- vllm/model_executor/models/exaone.py | 17 +++++++++++------ vllm/model_executor/models/granite.py | 17 +++++++++++------ vllm/model_executor/models/phimoe.py | 12 ++++++++++-- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 4a1c367de3f62..985f8a2a9e5a9 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -54,7 +54,8 @@ from vllm.utils import is_hip from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter, + make_layers) class ExaoneGatedMLP(nn.Module): @@ -365,12 +366,13 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, + self.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -484,9 +486,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return model_output def compute_logits( diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index b0325e8b616c8..a35851ef6e958 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -52,7 +52,8 @@ from vllm.utils import is_hip from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter, + make_layers) class GraniteMLP(nn.Module): @@ -304,12 +305,13 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, + self.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -418,9 +420,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: 
Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return model_output def compute_logits( diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 25bc0590c745c..7170ff74a5110 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -47,6 +47,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA +from .utils import get_inputs_embeds class PhiMoEConfig(PretrainedConfig): @@ -462,8 +463,12 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) + residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -540,9 +545,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, inputs_embeds, + inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, From 56b9ac5f64e36c561c606bee6252873fb0e861ce Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Sep 2024 03:50:47 +0000 Subject: [PATCH 21/88] Fix typing --- vllm/model_executor/models/idefics2_vision_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index cc448ed28d2dc..12b57b182a310 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -287,7 +287,7 @@ def forward( self, pixel_values, patch_attention_mask: Optional[torch.BoolTensor] = None, - ) -> torch.tensor: + ) -> torch.Tensor: hidden_states = self.embeddings( pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) From decb8ab5a0f76fe9d05c3eb635e86134989a3240 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Sep 2024 04:07:00 +0000 Subject: [PATCH 22/88] Support embeds in minicpm --- vllm/model_executor/models/minicpm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 7cb10fc1e9577..1dd5099b6b8a9 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -488,7 +488,9 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds=inputs_embeds, + inputs_embeds_masks=inputs_embeds_masks) return hidden_states def compute_logits( From ee41cb70f019f819c84b4f90946270c688c86ba8 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Sep 2024 04:07:18 +0000 Subject: 
[PATCH 23/88] Fix typing 2 --- vllm/model_executor/models/utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 245d2d273c353..7b1c3a1407166 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,5 +1,5 @@ -from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, - Union, overload) +from typing import (Callable, Dict, Iterable, List, Literal, Optional, Protocol, + Tuple, Union, overload) import torch import torch.nn as nn @@ -284,12 +284,14 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: def get_inputs_embeds( input_ids: torch.Tensor, - embeddings_module: torch.nn.Module, + embeddings_module: Callable[[torch.Tensor], torch.Tensor], inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Get the input embeddings from either `input_ids` and `inputs_embeds`.""" if inputs_embeds is not None: + assert inputs_embeds_masks is not None + if inputs_embeds_masks.all().item(): hidden_states = inputs_embeds else: @@ -297,6 +299,7 @@ def get_inputs_embeds( hidden_states[inputs_embeds_masks] = inputs_embeds else: hidden_states = embeddings_module(input_ids) + return hidden_states From fd58d4bad93357b91e0bf5b97d04f77879880b57 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Sep 2024 04:21:25 +0000 Subject: [PATCH 24/88] Disable `inputs_embeds` for multimodal models as it conflicts with multimodal embeds --- vllm/model_executor/models/fuyu.py | 14 +++++++------- vllm/model_executor/models/phi3v.py | 14 ++++++-------- vllm/model_executor/models/ultravox.py | 5 +---- vllm/model_executor/models/utils.py | 10 ++++++---- 4 files changed, 20 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index a2fd67e79b018..3e3fae62ce92f 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -43,7 +43,7 @@ SequenceData) from .interfaces import SupportsMultiModal -from .utils import get_inputs_embeds, merge_multimodal_embeddings +from .utils import merge_multimodal_embeddings logger = init_logger(__name__) @@ -278,27 +278,27 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - inputs_embeds_masks: Optional[torch.Tensor] = None, **kwargs: object, ): image_input = self._parse_and_validate_image_input(**kwargs) - inputs_embeds = get_inputs_embeds( - input_ids, self.language_model.model.embed_tokens, inputs_embeds, - inputs_embeds_masks) + if image_input is not None: vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.embed_tokens(input_ids) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, self.image_token_id) + else: + inputs_embeds = None + hidden_states = self.language_model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, - inputs_embeds_masks=inputs_embeds_masks) + ) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 2123438782929..94de0c5ae9edf 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -44,7 +44,7 @@ from .clip import 
dummy_image_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsMultiModal -from .utils import flatten_bn, get_inputs_embeds, merge_multimodal_embeddings +from .utils import flatten_bn, merge_multimodal_embeddings logger = init_logger(__name__) @@ -621,25 +621,23 @@ def forward(self, inputs_embeds_masks: Optional[torch.Tensor] = None, **kwargs: object): image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is not None: - inputs_embeds = get_inputs_embeds(input_ids, - self.model.get_input_embeddings, - inputs_embeds, - inputs_embeds_masks) vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, self.image_token_id) - inputs_embeds_masks = torch.ones_like(input_ids, dtype=torch.bool) input_ids = None + else: + inputs_embeds = None hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, - inputs_embeds=inputs_embeds, - inputs_embeds_masks=inputs_embeds_masks) + inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 1b7447806d19f..416fabda831a2 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -426,11 +426,9 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, audio_embeddings, _AUDIO_PLACEHOLDER_TOKEN) - inputs_embeds_masks = torch.ones_like(input_ids, dtype=torch.bool) input_ids = None else: inputs_embeds = None - inputs_embeds_masks = None hidden_states = self.language_model.model( input_ids=input_ids, @@ -438,8 +436,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches=kv_caches, attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - inputs_embeds_masks=inputs_embeds_masks) + inputs_embeds=inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 7b1c3a1407166..44312883225e9 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -283,21 +283,23 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: def get_inputs_embeds( - input_ids: torch.Tensor, + input_ids: Optional[torch.Tensor], embeddings_module: Callable[[torch.Tensor], torch.Tensor], inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Get the input embeddings from either `input_ids` and `inputs_embeds`.""" if inputs_embeds is not None: - assert inputs_embeds_masks is not None - - if inputs_embeds_masks.all().item(): + if inputs_embeds_masks is None or inputs_embeds_masks.all().item(): hidden_states = inputs_embeds else: + assert input_ids is not None, "inputs_embeds should not be masked out for multimodal models" + hidden_states = embeddings_module(input_ids) hidden_states[inputs_embeds_masks] = inputs_embeds else: + assert input_ids is not None, "inputs_embeds should be set for multimodal models" + hidden_states = embeddings_module(input_ids) return hidden_states From 5849509707a25b37cb1e1d9518ef08e96592f0b5 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Sep 2024 04:27:49 +0000 Subject: [PATCH 25/88] Reformat --- 
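Reviewer note on this reformat: the `get_inputs_embeds` helper being reflowed
here behaves as in the condensed sketch below (type annotations and the
assertion messages elided; the names are the ones used in the diff, and the
sketch itself is illustrative, not part of the patch):

    def get_inputs_embeds(input_ids, embeddings_module,
                          inputs_embeds=None, inputs_embeds_masks=None):
        if inputs_embeds is not None:
            if inputs_embeds_masks is None or inputs_embeds_masks.all().item():
                # Pure embedding prompt: every position is covered (or no
                # mask was built), so the precomputed embeddings are used
                # as-is.
                hidden_states = inputs_embeds
            else:
                # Mixed batch: embed the token positions first, then
                # overwrite the masked positions with the precomputed
                # prompt embeddings.
                hidden_states = embeddings_module(input_ids)
                hidden_states[inputs_embeds_masks] = inputs_embeds
        else:
            # Token-only prompt: fall back to the model's embedding layer.
            hidden_states = embeddings_module(input_ids)
        return hidden_states
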
vllm/model_executor/models/interfaces.py | 2 +- vllm/model_executor/models/minicpm.py | 3 +-- vllm/model_executor/models/utils.py | 10 ++++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index e5d1797fd57a2..d022a0783fd83 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -197,4 +197,4 @@ def supports_input_embeds(model: nn.Module) -> bool: """Check if the model supports input_embeds and input_embeds_masks.""" model_forward_params = inspect.signature(model.forward).parameters return ("inputs_embeds" in model_forward_params - and "inputs_embeds_masks" in model_forward_params) \ No newline at end of file + and "inputs_embeds_masks" in model_forward_params) diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 1dd5099b6b8a9..8f20943e3a4a9 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -489,8 +489,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, - inputs_embeds=inputs_embeds, - inputs_embeds_masks=inputs_embeds_masks) + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 44312883225e9..18b83165fc3f5 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,5 +1,5 @@ -from typing import (Callable, Dict, Iterable, List, Literal, Optional, Protocol, - Tuple, Union, overload) +from typing import (Callable, Dict, Iterable, List, Literal, Optional, + Protocol, Tuple, Union, overload) import torch import torch.nn as nn @@ -293,12 +293,14 @@ def get_inputs_embeds( if inputs_embeds_masks is None or inputs_embeds_masks.all().item(): hidden_states = inputs_embeds else: - assert input_ids is not None, "inputs_embeds should not be masked out for multimodal models" + msg = "inputs_embeds should not be masked out for multimodal models" + assert input_ids is not None, msg hidden_states = embeddings_module(input_ids) hidden_states[inputs_embeds_masks] = inputs_embeds else: - assert input_ids is not None, "inputs_embeds should be set for multimodal models" + msg = "inputs_embeds should be set for multimodal models" + assert input_ids is not None, msg hidden_states = embeddings_module(input_ids) From 23a487648a30e4c765701b4d4b1d128a838a2463 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Sep 2024 04:35:14 +0000 Subject: [PATCH 26/88] Optimize --- vllm/worker/model_runner.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c467918c43bed..49e81fa3e8f3a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -496,12 +496,18 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute tokens. 
if inter_data.is_prompt: if seq_data.prompt_embeds is None: - tokens = seq_data.get_token_ids()[context_len:seq_len] + tokens = seq_data.get_token_ids() + if context_len != 0 or seq_len < len(tokens): + tokens = tokens[context_len:seq_len] + input_embeds_mask = torch.zeros(seq_len - context_len, dtype=torch.bool) else: tokens = [0] * seq_len - input_embeds = seq_data.prompt_embeds[context_len:seq_len] + + if context_len != 0 or seq_len < len(seq_data.prompt_embeds): + input_embeds = seq_data.prompt_embeds[context_len:seq_len] + input_embeds_mask = torch.ones(seq_len - context_len, dtype=torch.bool) else: @@ -811,20 +817,20 @@ def build(self) -> ModelInputForGPU: max_encoder_seq_len = max(max_encoder_seq_len, inter_data.encoder_seq_len) - input_embeds = [ + input_embeds_lst = [ inter_data.input_embeds for inter_data in self.inter_data_list if inter_data.input_embeds is not None ] or None - input_embeds_masks = [ + input_embeds_masks_lst = [ inter_data.input_embeds_mask for inter_data in self.inter_data_list if inter_data.input_embeds_mask is not None ] or None - if input_embeds: - input_embeds = torch.cat(input_embeds).to( + if input_embeds_lst: + input_embeds = torch.cat(input_embeds_lst).to( device=self.runner.device, dtype=self.runner.model_config.dtype) - if input_embeds_masks: - input_embeds_masks = torch.cat(input_embeds_masks).to( + if input_embeds_masks_lst: + input_embeds_masks = torch.cat(input_embeds_masks_lst).to( self.runner.device) # Mapping from request IDs to sequence IDs. Used for Jamba models From fbf3f1099d70a0857865dfc59a9157e620066718 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Sep 2024 04:42:37 +0000 Subject: [PATCH 27/88] Fix unbound variables --- vllm/worker/model_runner.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 49e81fa3e8f3a..393c53140c87d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -505,8 +505,9 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, else: tokens = [0] * seq_len - if context_len != 0 or seq_len < len(seq_data.prompt_embeds): - input_embeds = seq_data.prompt_embeds[context_len:seq_len] + input_embeds = seq_data.prompt_embeds + if context_len != 0 or seq_len < len(input_embeds): + input_embeds = input_embeds[context_len:seq_len] input_embeds_mask = torch.ones(seq_len - context_len, dtype=torch.bool) @@ -820,18 +821,23 @@ def build(self) -> ModelInputForGPU: input_embeds_lst = [ inter_data.input_embeds for inter_data in self.inter_data_list if inter_data.input_embeds is not None - ] or None - input_embeds_masks_lst = [ - inter_data.input_embeds_mask for inter_data in self.inter_data_list - if inter_data.input_embeds_mask is not None - ] or None + ] if input_embeds_lst: input_embeds = torch.cat(input_embeds_lst).to( device=self.runner.device, dtype=self.runner.model_config.dtype) + else: + input_embeds = None + + input_embeds_masks_lst = [ + inter_data.input_embeds_mask for inter_data in self.inter_data_list + if inter_data.input_embeds_mask is not None + ] if input_embeds_masks_lst: input_embeds_masks = torch.cat(input_embeds_masks_lst).to( self.runner.device) + else: + input_embeds_masks = None # Mapping from request IDs to sequence IDs. Used for Jamba models # that manages the cache by itself. 
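Taken together, the "Optimize" and "Fix unbound variables" changes above leave
the prompt-embeddings path of the GPU model runner in the following shape: each
prefill sequence contributes either its real token IDs, or placeholder IDs plus
its prompt embeddings, and the builder concatenates the per-group pieces once
at the end. The sketch below condenses that logic; the wrapper functions and
the bare `device`/`dtype` parameters are illustrative only, while the attribute
names are the ones used in the diffs (chunked-prefill slicing details elided):

    import torch

    def per_sequence_inputs(seq_data, context_len, seq_len):
        # Sketch of the prefill branch of _compute_lens().
        if seq_data.prompt_embeds is None:
            tokens = seq_data.get_token_ids()[context_len:seq_len]
            input_embeds = None
            input_embeds_mask = torch.zeros(seq_len - context_len,
                                            dtype=torch.bool)
        else:
            # Placeholder token IDs; the embeddings carry the actual content.
            tokens = [0] * seq_len
            input_embeds = seq_data.prompt_embeds[context_len:seq_len]
            input_embeds_mask = torch.ones(seq_len - context_len,
                                           dtype=torch.bool)
        return tokens, input_embeds, input_embeds_mask

    def batch_embeds(inter_data_list, device, dtype):
        # Sketch of how build() merges the per-group pieces, or passes None
        # through so that token-only batches are unaffected.
        embeds = [d.input_embeds for d in inter_data_list
                  if d.input_embeds is not None]
        masks = [d.input_embeds_mask for d in inter_data_list
                 if d.input_embeds_mask is not None]
        input_embeds = (torch.cat(embeds).to(device=device, dtype=dtype)
                        if embeds else None)
        input_embeds_masks = (torch.cat(masks).to(device)
                              if masks else None)
        return input_embeds, input_embeds_masks
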
From 9822652c436bf002b5d70822cb28a18f2f6f964d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Sep 2024 06:42:32 +0000 Subject: [PATCH 28/88] Cleanup --- vllm/utils.py | 30 +++++++++++++++++++++++++++--- vllm/worker/model_runner.py | 21 +++++++++------------ 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 060b387ec7834..d984360d7dfe9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -18,8 +18,8 @@ from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, - Hashable, List, Literal, Optional, OrderedDict, Set, Tuple, - Type, TypeVar, Union, overload) + Hashable, List, Literal, Optional, OrderedDict, Protocol, + Set, Tuple, Type, TypeVar, Union, overload) from uuid import uuid4 import numpy as np @@ -29,7 +29,7 @@ import torch.types import yaml from packaging.version import Version -from typing_extensions import ParamSpec, TypeIs, assert_never +from typing_extensions import ParamSpec, Self, TypeIs, assert_never import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger @@ -922,6 +922,30 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]: return [item for sublist in lists for item in sublist] +class _SequenceLike(Protocol): + + def __len__(self) -> int: + ... + + def __getitem__(self, __idx: slice) -> Self: + ... + + +_S = TypeVar("_S", bound=_SequenceLike) + + +def maybe_slice(lst: _S, start: int, stop: int) -> _S: + """ + Slice a sequence, returning the original one where possible. + + This avoids unnecessary copying of the sequence. + """ + if start != 0 or stop < len(lst): + return lst[start:stop] + + return lst + + def init_cached_hf_modules() -> None: """ Lazy initialization of the Hugging Face modules. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 393c53140c87d..0067381dfa2ae 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -48,7 +48,7 @@ from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available, - supports_dynamo) + maybe_slice, supports_dynamo) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -491,24 +491,18 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, context_len = seq_len - 1 seq_len = min(seq_len, context_len + token_chunk_size) - input_embeds = None - input_embeds_mask = None # Compute tokens. if inter_data.is_prompt: if seq_data.prompt_embeds is None: - tokens = seq_data.get_token_ids() - if context_len != 0 or seq_len < len(tokens): - tokens = tokens[context_len:seq_len] - + tokens = maybe_slice(seq_data.get_token_ids(), context_len, + seq_len) + input_embeds = None input_embeds_mask = torch.zeros(seq_len - context_len, dtype=torch.bool) else: tokens = [0] * seq_len - - input_embeds = seq_data.prompt_embeds - if context_len != 0 or seq_len < len(input_embeds): - input_embeds = input_embeds[context_len:seq_len] - + input_embeds = maybe_slice(seq_data.prompt_embeds, context_len, + seq_len) input_embeds_mask = torch.ones(seq_len - context_len, dtype=torch.bool) else: @@ -516,6 +510,9 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # tokens. 
tokens = seq_data.get_last_token_id() + input_embeds = None + input_embeds_mask = None + inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len From 53962c47ab3815c14cc737ae141064843bb1811b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Sep 2024 06:43:29 +0000 Subject: [PATCH 29/88] Cleanup 2 --- vllm/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index d984360d7dfe9..ec7a96f5404d2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -934,16 +934,16 @@ def __getitem__(self, __idx: slice) -> Self: _S = TypeVar("_S", bound=_SequenceLike) -def maybe_slice(lst: _S, start: int, stop: int) -> _S: +def maybe_slice(seq: _S, start: int, stop: int) -> _S: """ Slice a sequence, returning the original one where possible. This avoids unnecessary copying of the sequence. """ - if start != 0 or stop < len(lst): - return lst[start:stop] + if start != 0 or stop < len(seq): + return seq[start:stop] - return lst + return seq def init_cached_hf_modules() -> None: From b8137aa2448005dff481659e3e1dfd3645aca68a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Sep 2024 10:44:36 +0000 Subject: [PATCH 30/88] Fix unbound `prompt_embeds` in validation and clean it up --- vllm/engine/llm_engine.py | 54 ++++++++++++++++++++++++++------------- vllm/inputs/data.py | 6 ++--- 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 863a3c1a483be..f69cafb948f7d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -625,6 +625,7 @@ def _add_processed_request( trace_headers: Optional[Mapping[str, str]] = None, ) -> None: self._validate_model_inputs(processed_inputs) + # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) @@ -678,17 +679,6 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() - def _support_prompt_embeds(self) -> Tuple[bool, str]: - if self.speculative_config is not None: - return False, "Speculative decoding does not support prompt_embeds." - driver_worker = self.model_executor.driver_worker - model_runner = driver_worker.worker.model_runner if isinstance( - driver_worker, WorkerWrapperBase) else driver_worker.model_runner - if model_runner.model_supports_input_embeds: - return True, "" - return False, (f"Model {self.model_config.model} does not support " - "input embeddings, but prompt_embeds was provided.") - def add_request( self, request_id: str, @@ -753,10 +743,6 @@ def add_request( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) - if preprocessed_inputs.get("prompt_embeds") is not None: - support_prompt_embeds, error_msg = self._support_prompt_embeds() - if not support_prompt_embeds: - raise ValueError(error_msg) processed_inputs = self.input_processor(preprocessed_inputs) self._add_processed_request( @@ -1712,16 +1698,40 @@ def is_encoder_decoder_model(self): def is_embedding_model(self): return self.model_config.is_embedding_model + def _support_prompt_embeds(self) -> Tuple[bool, str]: + if self.speculative_config is not None: + return False, "Speculative decoding does not support prompt_embeds." 
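+        # Otherwise this depends on the loaded model: the driver worker's
+        # model runner exposes `model_supports_input_embeds`, which reflects
+        # whether the model's forward() accepts the optional `inputs_embeds` /
+        # `inputs_embeds_masks` parameters (cf. `supports_input_embeds()` in
+        # models/interfaces.py).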
+ driver_worker = self.model_executor.driver_worker + model_runner = driver_worker.worker.model_runner if isinstance( + driver_worker, WorkerWrapperBase) else driver_worker.model_runner + if model_runner.model_supports_input_embeds: + return True, "" + return False, (f"Model {self.model_config.model} does not support " + "input embeddings, but prompt_embeds was provided.") + def _validate_model_inputs(self, inputs: Union[LLMInputs, EncoderDecoderLLMInputs]): if self.is_encoder_decoder_model(): prompt_ids = inputs.get("encoder_prompt_token_ids") else: prompt_ids = inputs.get("prompt_token_ids") - prompt_embeds = inputs.get("prompt_embeds") - if (prompt_ids is None - or len(prompt_ids) == 0) and prompt_embeds is None: + prompt_embeds = inputs.get("prompt_embeds") + + if prompt_ids is None: + if prompt_embeds is None: + raise ValueError("You must provide a prompt") + + self._validate_prompt_embeds(prompt_embeds) + else: + if prompt_embeds is None: + self._validate_prompt_ids(prompt_ids) + else: + raise ValueError("You can only provide either tokens or " + "embeddings, not both") + + def _validate_prompt_ids(self, prompt_ids: List[int]): + if len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") if self.model_config.is_multimodal_model: @@ -1739,3 +1749,11 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs, # TODO: Find out how many placeholder tokens are there so we can # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def _validate_prompt_embeds(self, prompt_embeds: torch.Tensor): + if len(prompt_embeds) == 0: + raise ValueError("Prompt cannot be empty") + + support_prompt_embeds, error_msg = self._support_prompt_embeds() + if not support_prompt_embeds: + raise ValueError(error_msg) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index b3c49076be3bb..e6037c14ff9f5 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -81,9 +81,9 @@ class EmbedsPrompt(TypedDict): # TODO: Make fields ReadOnly once mypy supports it class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): - """Represents an encoder/decoder model input prompt, - comprising an explicit encoder prompt and a - decoder prompt. + """ + Represents an encoder/decoder model input prompt, + comprising an explicit encoder prompt and a decoder prompt. 
The encoder and decoder prompts, respectively, may formatted according to any of the From 2cf3b4bec0f750b6bd70fcf83b415a1a3c6bdaf0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Sep 2024 10:46:02 +0000 Subject: [PATCH 31/88] Indent --- vllm/engine/llm_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f69cafb948f7d..7ec03db7c7687 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1721,8 +1721,8 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs, if prompt_ids is None: if prompt_embeds is None: raise ValueError("You must provide a prompt") - - self._validate_prompt_embeds(prompt_embeds) + else: + self._validate_prompt_embeds(prompt_embeds) else: if prompt_embeds is None: self._validate_prompt_ids(prompt_ids) From 716b64ad9bc832a9d427bca273bf49a3b507b899 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 19 Sep 2024 12:37:11 +0000 Subject: [PATCH 32/88] Fix unbound local and cleanup 2 --- vllm/inputs/preprocess.py | 140 +++++++++++++++++++++++--------------- 1 file changed, 87 insertions(+), 53 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index a38c7dcb22ebe..5117ca5a66b4e 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,5 +1,6 @@ import asyncio -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Union import torch from typing_extensions import assert_never @@ -19,11 +20,21 @@ logger = init_logger(__name__) -PromptComponents = Tuple[Optional[str], List[int], Optional[torch.Tensor], - Optional["MultiModalDataDict"]] -DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], - Optional[torch.Tensor], - Optional["MultiModalDataDict"]] + +@dataclass(frozen=True) +class PromptComponents: + prompt: Optional[str] + prompt_token_ids: List[int] + prompt_embeds: Optional[torch.Tensor] = None + multi_modal_data: Optional["MultiModalDataDict"] = None + + +@dataclass(frozen=True) +class DecoderPromptComponents: + prompt: Optional[str] + prompt_token_ids: Optional[List[int]] + prompt_embeds: Optional[torch.Tensor] = None + multi_modal_data: Optional["MultiModalDataDict"] = None class InputPreprocessor: @@ -237,13 +248,19 @@ def _extract_prompt_components( request_id=request_id, lora_request=lora_request, ) - multi_modal_data = None - prompt_embeds = None - elif parsed["type"] == "tokens": - prompt = None + + return PromptComponents(prompt=prompt, + prompt_token_ids=prompt_token_ids) + + if parsed["type"] == "tokens": prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") - elif parsed["type"] == "text": + + return PromptComponents(prompt=None, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data) + + if parsed["type"] == "text": prompt = parsed["content"]["prompt"] prompt_token_ids = self._tokenize_prompt( prompt, @@ -251,16 +268,21 @@ def _extract_prompt_components( lora_request=lora_request, ) multi_modal_data = parsed["content"].get("multi_modal_data") - prompt_embeds = None - elif parsed["type"] == "embeds": - prompt = None - prompt_token_ids = [] + + return PromptComponents(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data) + + if parsed["type"] == "embeds": prompt_embeds = parsed["content"]["prompt_embeds"] multi_modal_data = parsed["content"].get("multi_modal_data") - else: - 
assert_never(parsed) - return prompt, prompt_token_ids, prompt_embeds, multi_modal_data + return PromptComponents(prompt=None, + prompt_token_ids=[], + multi_modal_data=multi_modal_data, + prompt_embeds=prompt_embeds) + + assert_never(parsed) async def _extract_prompt_components_async( self, @@ -278,14 +300,19 @@ async def _extract_prompt_components_async( request_id=request_id, lora_request=lora_request, ) - multi_modal_data = None - prompt_embeds = None - elif parsed["type"] == "tokens": - prompt = None + + return PromptComponents(prompt=prompt, + prompt_token_ids=prompt_token_ids) + + if parsed["type"] == "tokens": prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") - prompt_embeds = None - elif parsed["type"] == "text": + + return PromptComponents(prompt=None, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data) + + if parsed["type"] == "text": prompt = parsed["content"]["prompt"] prompt_token_ids = await self._tokenize_prompt_async( prompt, @@ -293,37 +320,40 @@ async def _extract_prompt_components_async( lora_request=lora_request, ) multi_modal_data = parsed["content"].get("multi_modal_data") - prompt_embeds = None - elif parsed["type"] == "embeds": - prompt = None - prompt_token_ids = [] + + return PromptComponents(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data) + + if parsed["type"] == "embeds": prompt_embeds = parsed["content"]["prompt_embeds"] multi_modal_data = parsed["content"].get("multi_modal_data") - else: - assert_never(parsed) - return prompt, prompt_token_ids, prompt_embeds, multi_modal_data + return PromptComponents(prompt=None, + prompt_token_ids=[], + multi_modal_data=multi_modal_data, + prompt_embeds=prompt_embeds) + + assert_never(parsed) def _build_enc_dec_llm_inputs( self, encoder_comps: PromptComponents, - decoder_comps: DecoderPromptComponents, + decoder_comps: Union[PromptComponents, DecoderPromptComponents], ) -> EncoderDecoderLLMInputs: - encoder_prompt, encoder_prompt_ids, _, encoder_mm_data = encoder_comps - decoder_prompt, decoder_prompt_ids, _, decoder_mm_data = decoder_comps - - if encoder_mm_data is not None or decoder_mm_data is not None: + if (encoder_comps.multi_modal_data is not None + or decoder_comps.multi_modal_data is not None): raise ValueError("Multi-modal encoder-decoder models are " "not supported yet") - decoder_prompt_ids = ( - self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) + decoder_prompt_ids = self._prepare_decoder_input_ids_for_generation( + decoder_comps.prompt_token_ids) return EncoderDecoderLLMInputs( prompt_token_ids=decoder_prompt_ids, - prompt=decoder_prompt, - encoder_prompt_token_ids=encoder_prompt_ids, - encoder_prompt=encoder_prompt, + prompt=decoder_comps.prompt, + encoder_prompt_token_ids=encoder_comps.prompt_token_ids, + encoder_prompt=encoder_comps.prompt, ) def _process_encoder_decoder_prompt( @@ -365,7 +395,7 @@ def _process_encoder_decoder_prompt( ''' encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents + decoder_comps: Union[PromptComponents, DecoderPromptComponents] if is_explicit_encoder_decoder_prompt(inputs): encoder_comps = self._extract_prompt_components( @@ -374,7 +404,8 @@ def _process_encoder_decoder_prompt( ) if (decoder_input := inputs["decoder_prompt"]) is None: - decoder_comps = None, None, None, None + decoder_comps = DecoderPromptComponents(prompt=None, + prompt_token_ids=None) else: decoder_comps = self._extract_prompt_components( 
decoder_input, @@ -386,7 +417,8 @@ def _process_encoder_decoder_prompt( request_id=request_id, ) - decoder_comps = None, None, None, None + decoder_comps = DecoderPromptComponents(prompt=None, + prompt_token_ids=None) return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) @@ -397,7 +429,7 @@ async def _process_encoder_decoder_prompt_async( ) -> EncoderDecoderLLMInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents + decoder_comps: Union[PromptComponents, DecoderPromptComponents] if is_explicit_encoder_decoder_prompt(inputs): encoder_task = self._extract_prompt_components_async( @@ -407,7 +439,8 @@ async def _process_encoder_decoder_prompt_async( if (decoder_input := inputs["decoder_prompt"]) is None: encoder_comps = await encoder_task - decoder_comps = None, None, None, None + decoder_comps = DecoderPromptComponents(prompt=None, + prompt_token_ids=None) else: decoder_task = self._extract_prompt_components_async( decoder_input, @@ -422,7 +455,8 @@ async def _process_encoder_decoder_prompt_async( request_id=request_id, ) - decoder_comps = None, None, None, None + decoder_comps = DecoderPromptComponents(prompt=None, + prompt_token_ids=None) return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) @@ -431,15 +465,15 @@ def _build_decoder_only_llm_inputs( prompt_comps: PromptComponents, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> LLMInputs: - prompt, prompt_token_ids, prompt_embeds, multi_modal_data = prompt_comps - prompt_token_ids = self._apply_prompt_adapter( - prompt_token_ids, prompt_adapter_request=prompt_adapter_request) + prompt_comps.prompt_token_ids, + prompt_adapter_request=prompt_adapter_request, + ) return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=prompt, - prompt_embeds=prompt_embeds, - multi_modal_data=multi_modal_data) + prompt=prompt_comps.prompt, + prompt_embeds=prompt_comps.prompt_embeds, + multi_modal_data=prompt_comps.multi_modal_data) def _process_decoder_only_prompt( self, From 7dd3d86737a2d436b7d99b461582dedc0cdc1c07 Mon Sep 17 00:00:00 2001 From: Nan2018 Date: Thu, 19 Sep 2024 11:56:58 -0500 Subject: [PATCH 33/88] fix: gemma --- vllm/model_executor/models/gemma.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 1034225c15701..44cad1cf7be61 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -350,9 +350,12 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds, - inputs_embeds_masks) + hidden_states = self.model(input_ids, + positions, + kv_caches, + attn_metadata, + inputs_embeds=inputs_embeds, + inputs_embeds_masks=inputs_embeds_masks) return hidden_states def compute_logits( From 805901402941f6dd7b94f3f6b44a60b5ced2c747 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 14:26:39 +0000 Subject: [PATCH 34/88] Have two distinct `SequenceData` classes, one for tokens and one for embeddings --- tests/samplers/test_sampler.py | 30 +-- tests/spec_decode/utils.py | 14 +- tests/test_logits_processor.py | 9 +- tests/test_sequence.py | 7 +- .../test_encoder_decoder_model_runner.py | 23 +- tests/worker/test_model_runner.py | 42 ++-- vllm/inputs/data.py | 4 +- vllm/inputs/registry.py | 14 +- vllm/sequence.py 
| 231 ++++++++++++------ 9 files changed, 215 insertions(+), 159 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 19a5ca5e27502..eed47b57b5f68 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -1,6 +1,5 @@ import itertools import random -from array import array from typing import Dict, List, Optional, Tuple from unittest.mock import Mock, patch @@ -12,8 +11,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, - SequenceData, SequenceGroupMetadata) +from vllm.sequence import (SamplingParams, SequenceTokenData, + SequenceGroupMetadata) from vllm.utils import Counter, is_pin_memory_available @@ -59,9 +58,7 @@ def _do_sample( SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) - }, + seq_data={0: SequenceTokenData.from_seq([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -205,9 +202,8 @@ def create_sampling_params(min_tokens, return sampling_params def create_sequence_data(num_input=3, num_generated=0): - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, - random.choices(range(0, VOCAB_SIZE), k=num_input))) + seq_data = SequenceTokenData.from_seq( + random.choices(range(0, VOCAB_SIZE), k=num_input)) if num_generated > 0: seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE), k=num_generated) @@ -238,7 +234,7 @@ def generate_test_case(): eos_token_id=eos_token_id, stop_token_ids=stop_token_ids) - seq_data: Dict[int, SequenceData] = {} + seq_data: Dict[int, SequenceTokenData] = {} seq_group_penalization: List[bool] = [] for _ in range(num_seqs): num_input = random.randint(1, 100) @@ -511,9 +507,7 @@ def test_sampler_mixed(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) - }, + seq_data={0: SequenceTokenData.from_seq([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -613,9 +607,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) - }, + seq_data={0: SequenceTokenData.from_seq([1, 2, 3])}, sampling_params=SamplingParams( temperature=1, top_k=top_k, @@ -699,11 +691,7 @@ def test_sampling_params(sampling_params: List[SamplingParams]): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: - SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, - [1, 2, 3])) - }, + seq_data={0: SequenceTokenData.from_seq([1, 2, 3])}, sampling_params=sampling_params[i], block_tables={0: [1]}, )) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 9075a433eb66e..e18517dc2e4bd 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -1,4 +1,3 @@ -from array import array from itertools import count from typing import Callable, Dict, List, Optional from typing import Sequence as GenericSequence @@ -11,9 +10,9 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, - CompletionSequenceGroupOutput, Logprob, - SequenceData, 
SequenceGroupMetadata, SequenceOutput) +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SequenceTokenData, SequenceGroupMetadata, + SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner @@ -139,11 +138,8 @@ def create_seq_group_metadata_from_prompts( is_prompt=len(cont_token_ids) == 0, seq_data={ i: - SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]), - _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, - cont_token_ids[:]), - ), + SequenceTokenData.from_seq(prompt_token_ids[:], + cont_token_ids[:]), }, sampling_params=SamplingParams(temperature=0.0, ), block_tables={i: block_allocations[i][:]}, diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 1ce49a50688ae..3cbd5e7bacb0e 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -1,5 +1,4 @@ import random -from array import array from typing import Tuple from unittest.mock import patch @@ -9,8 +8,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, - SequenceData, SequenceGroupMetadata) +from vllm.sequence import (SamplingParams, SequenceTokenData, + SequenceGroupMetadata) from vllm.utils import is_pin_memory_available @@ -71,9 +70,7 @@ def pick_ith(token_ids, logits): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) - }, + seq_data={0: SequenceTokenData.from_seq([1, 2, 3])}, sampling_params=SamplingParams(temperature=0, logits_processors=[pick_ith]), block_tables={0: [1]}, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 348ba7dd41d99..b5515417da50a 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,10 +1,7 @@ -from array import array - import pytest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, - CompletionSequenceGroupOutput, SequenceData, +from vllm.sequence import (CompletionSequenceGroupOutput, SequenceTokenData, SequenceOutput) from .core.utils import create_dummy_prompt @@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs): def test_sequence_data_prefill(): - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4])) + seq_data = SequenceTokenData.from_seq([1, 2, 3, 4]) assert seq_data.get_num_uncomputed_tokens() == 4 assert seq_data.get_num_computed_tokens() == 0 # advance by 2 diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 27cdf5f339ede..2dc858a66b7ca 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -1,13 +1,12 @@ import itertools -from array import array from typing import List import pytest import torch from vllm.engine.arg_utils import EngineArgs -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, - SequenceData, SequenceGroupMetadata) +from vllm.sequence import (SamplingParams, SequenceTokenData, + SequenceGroupMetadata) from vllm.utils import is_cpu, make_tensor_with_pad from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import _get_graph_batch_size @@ -119,12 
+118,10 @@ def test_prepare_prompt(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, - range(seq_len))) + seq_data = SequenceTokenData.from_seq(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_lens.append(encoder_seq_len) - encoder_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len))) + encoder_seq_data = SequenceTokenData.from_seq(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -317,11 +314,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) + seq_data = SequenceTokenData.from_seq(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) + encoder_seq_data = SequenceTokenData.from_seq(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -523,11 +518,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) + seq_data = SequenceTokenData.from_seq(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) + encoder_seq_data = SequenceTokenData.from_seq(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index f58b5a5e98998..7478cba5f55df 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,5 +1,4 @@ import random -from array import array from typing import List import pytest @@ -9,8 +8,8 @@ init_distributed_environment) from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, - SequenceData, SequenceGroupMetadata) +from vllm.sequence import (SamplingParams, SequenceTokenData, + SequenceGroupMetadata) from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @@ -52,13 +51,13 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio): seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) if random.random() < prompt_embeds_ratio: - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len)), - torch.rand(seq_len, 10)) + seq_data = SequenceTokenData.from_seq( + range(seq_len), + torch.rand(seq_len, 10).tolist(), + ) input_embeds_len += seq_len else: - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len))) + seq_data = SequenceTokenData.from_seq(range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -185,12 +184,13 @@ def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio): context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) if 
random.random() < prompt_embeds_ratio: - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)), - torch.rand(context_len, 10)) + seq_data = SequenceTokenData.from_seq( + [], + torch.rand(context_len, 10).tolist(), + ) input_embeds_len += context_len else: - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))) + seq_data = SequenceTokenData.from_seq(range(context_len)) seq_data.update_num_computed_tokens(context_len) # Append one token ID since prefill is finished. seq_data.append_token_id(1, 0) @@ -364,12 +364,13 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) if random.random() < prompt_embeds_ratio: - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)), - torch.rand(seq_len, 10)) + seq_data = SequenceTokenData.from_seq( + [], + torch.rand(seq_len, 10).tolist(), + ) input_embeds_len += seq_len else: - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, range(seq_len))) + seq_data = SequenceTokenData.from_seq(range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -386,11 +387,12 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 if random.random() < prompt_embeds_ratio: - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, range(0)), - torch.rand(context_len, 10)) + seq_data = SequenceTokenData.from_seq( + [], + torch.rand(context_len, 10).tolist(), + ), else: - prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)) - seq_data = SequenceData(prompt_toks) + seq_data = SequenceTokenData.from_seq(range(context_len)) seq_data.append_token_id(1, 0) seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index e6037c14ff9f5..3e233ad12a71e 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -123,7 +123,7 @@ class LLMInputs(TypedDict): This specifies the data required for decoder-only models. """ - prompt_token_ids: List[int] + prompt_token_ids: Optional[List[int]] """The token IDs of the prompt.""" prompt: NotRequired[Optional[str]] @@ -150,7 +150,7 @@ class EncoderDecoderLLMInputs(LLMInputs): This specifies the required data for encoder-decoder models. """ - encoder_prompt_token_ids: List[int] + encoder_prompt_token_ids: Optional[List[int]] """The token IDs of the encoder prompt.""" encoder_prompt: NotRequired[Optional[str]] diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index ae6c6c05d9f72..f74b1ea4e5825 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,5 +1,4 @@ import functools -from array import array from collections import UserDict from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, @@ -16,7 +15,7 @@ if TYPE_CHECKING: from vllm.config import ModelConfig from vllm.multimodal import MultiModalDataDict, MultiModalRegistry - from vllm.sequence import SequenceData + from vllm.sequence import SequenceTokenData logger = init_logger(__name__) @@ -73,7 +72,7 @@ def __call__( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], - ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + ) -> Tuple["SequenceTokenData", Optional["MultiModalDataDict"]]: """ Create dummy data to be inputted into the model. 
@@ -119,7 +118,7 @@ def _default_dummy_data_factory( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], - ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + ) -> Tuple["SequenceTokenData", Optional["MultiModalDataDict"]]: """ The default dummy data factory represents the longest possible text that can be inputted to the model. @@ -128,10 +127,9 @@ def _default_dummy_data_factory( :data:`InputProcessor` is not applied to the dummy data. """ # Avoid circular import - from vllm.sequence import SequenceData + from vllm.sequence import SequenceTokenData - dummy_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len) + dummy_seq_data = SequenceTokenData.from_seq([0] * seq_len) dummy_multi_modal_data = None return dummy_seq_data, dummy_multi_modal_data @@ -163,7 +161,7 @@ def dummy_data_for_profiling( model_config: "ModelConfig", seq_len: int, mm_registry: "MultiModalRegistry", - ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + ) -> Tuple["SequenceTokenData", Optional["MultiModalDataDict"]]: """ Create dummy data for profiling the memory usage of a model. diff --git a/vllm/sequence.py b/vllm/sequence.py index 669b3b5b88d2c..ee53655fac0c5 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,7 +5,8 @@ from array import array from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, + Optional) from typing import Sequence as GenericSequence from typing import Set, Tuple, Union, cast @@ -133,33 +134,24 @@ class SequenceDataDelta( new_stage: SequenceStage -class SequenceData(msgspec.Struct, - omit_defaults=True): # type: ignore[call-arg] +class SequenceDataMixin(msgspec.Struct, + omit_defaults=True): # type: ignore[call-arg] """Data associated with a sequence. Args: - prompt_token_ids: The token IDs of the prompt. - prompt_embeds: The embeddings of the prompt. output_token_ids: The token IDs of the output. Set to an empty list if None. Attributes: - prompt_token_ids: The token IDs of the prompt. - prompt_embeds: The embeddings of the prompt. output_token_ids: The token IDs of the output. cumulative_logprob: The cumulative log probability of the output. """ - # NOTE: we cannot use Union[List, array] because msgspec cannot support - # union of 2 list types. - _prompt_token_ids: array - _prompt_embeds: Optional[List[torch.Tensor]] = None _output_token_ids: array = msgspec.field( default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) ### The below fields should not be passed as an argument ### _cumulative_logprob: float = 0.0 - _prompt_token_ids_tuple: Tuple[int, - ...] = msgspec.field(default_factory=tuple) + # The number of tokens that are computed (that run against the model). _num_computed_tokens: int = 0 _stage: SequenceStage = SequenceStage.PREFILL @@ -173,47 +165,18 @@ class SequenceData(msgspec.Struct, _mrope_position_delta: Optional[int] = None def __post_init__(self) -> None: - assert self._prompt_token_ids.typecode == "l" assert self._output_token_ids.typecode == "l" - self._prompt_token_ids_tuple: Tuple[int, ...] 
= tuple( - self._prompt_token_ids) self._update_cached_all_tokens() - def _update_cached_all_tokens(self): - assert isinstance(self._prompt_token_ids, array) + def _update_cached_all_tokens(self) -> None: assert isinstance(self._output_token_ids, array) - self._cached_all_token_ids: List[int] = list(self._prompt_token_ids + - self._output_token_ids) + + self._cached_all_token_ids = list(self._output_token_ids) @property def cumulative_logprob(self) -> float: return self._cumulative_logprob - @property - def prompt_token_ids(self) -> Tuple[int, ...]: - return self._prompt_token_ids_tuple - - @prompt_token_ids.setter - def prompt_token_ids(self, new_prompt_token_ids) -> None: - raise NotImplementedError - - @property - def prompt_embeds(self) -> Optional[torch.Tensor]: - return self._prompt_embeds - - @prompt_embeds.setter - def prompt_embeds(self, new_prompt_embeds: Optional[torch.Tensor]) -> None: - self._prompt_embeds = new_prompt_embeds - - @property - def prompt_token_ids_array(self) -> array: - """Return the prompt token ids in array type. - - Note that the array is in "I" type, and it is not compatible - with torch.long (2 bytes vs 4 bytes). So beware of the usage. - """ - return self._prompt_token_ids - @property def output_token_ids(self) -> Tuple[int, ...]: return tuple(self._output_token_ids) @@ -249,13 +212,11 @@ def append_token_id(self, token_id: int, logprob: float) -> None: self._cumulative_logprob += logprob def get_len(self) -> int: - if self._prompt_embeds is None: - return len(self._output_token_ids) + len(self._prompt_token_ids) - else: - return len(self._output_token_ids) + len(self._prompt_embeds) + return self.get_prompt_len() + len(self._output_token_ids) + @abstractmethod def get_prompt_len(self) -> int: - return len(self._prompt_token_ids) + raise NotImplementedError def get_output_len(self) -> int: return len(self._output_token_ids) @@ -263,17 +224,6 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self._cached_all_token_ids - def get_prefix_token_ids( - self, num_tokens: int - ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: - """Get prefix tokens, and make the return value hashable""" - prompt_length = self.get_prompt_len() - if num_tokens > prompt_length: - return (self._prompt_token_ids_tuple, - tuple(self._output_token_ids[:num_tokens - prompt_length])) - else: - return (self._prompt_token_ids_tuple[:num_tokens], None) - def get_num_computed_tokens(self) -> int: """Return the number of prefill tokens that are already computed.""" return self._num_computed_tokens @@ -303,14 +253,6 @@ def get_num_uncomputed_tokens(self) -> int: # prefill for both prompt and output. return self.get_len() - self.get_num_computed_tokens() - def get_last_token_id(self) -> int: - if not self._output_token_ids: - return self._prompt_token_ids[-1] - return self._output_token_ids[-1] - - def get_prompt_token_ids(self) -> Tuple[int, ...]: - return self.prompt_token_ids - def get_output_token_ids(self) -> Tuple[int, ...]: return self.output_token_ids @@ -333,14 +275,152 @@ def apply_delta(self, delta: SequenceDataDelta): def stage(self) -> SequenceStage: return self._stage + +class _TokenBase(msgspec.Struct, omit_defaults=True): # type: ignore[call-arg] + # NOTE: we cannot use Union[List, array] because msgspec cannot support + # union of 2 list types. 
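+    # The prompt payload lives in this tiny base class (and in `_EmbedBase`
+    # below for the tensor variant) so that, with the mixin ordering used by
+    # `SequenceTokenData` and `SequenceEmbedData`, it stays the first
+    # positional __init__ argument. Each subclass also carries a
+    # `type: Literal[...]` field so that msgspec can treat
+    # `SequenceData = Union[SequenceTokenData, SequenceEmbedData]` as a
+    # tagged union.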
+ _prompt_token_ids: array + + +# Note the ordering here so `_prompt_token_ids` is the first __init__ argument +class SequenceTokenData(SequenceDataMixin, _TokenBase, + omit_defaults=True): # type: ignore[call-arg] + """Data associated with a sequence. + + Args: + prompt_token_ids: The token IDs of the prompt. + output_token_ids: The token IDs of the output. Set to an empty list if + None. + + Attributes: + prompt_token_ids: The token IDs of the prompt. + output_token_ids: The token IDs of the output. + cumulative_logprob: The cumulative log probability of the output. + """ + + ### The below fields should not be passed as an argument ### + _prompt_token_ids_tuple: Tuple[int, + ...] = msgspec.field(default_factory=tuple) + + # For tagged union + type: Literal["tokens"] = "tokens" + + @staticmethod + def from_seq( + prompt_token_ids: GenericSequence[int], + output_token_ids: Optional[GenericSequence[int]] = None, + ) -> "SequenceTokenData": + prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + prompt_token_ids) + + if output_token_ids is None: + return SequenceTokenData(prompt_token_ids_arr) + + output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + output_token_ids) + + return SequenceTokenData(prompt_token_ids_arr, output_token_ids_arr) + + def __post_init__(self) -> None: + assert self._prompt_token_ids.typecode == "l" + assert self._output_token_ids.typecode == "l" + + self._prompt_token_ids_tuple = tuple(self._prompt_token_ids) + + self._update_cached_all_tokens() + + def _update_cached_all_tokens(self) -> None: + assert isinstance(self._prompt_token_ids, array) + assert isinstance(self._output_token_ids, array) + + self._cached_all_token_ids = list(self._prompt_token_ids + + self._output_token_ids) + + @property + def prompt_token_ids(self) -> Tuple[int, ...]: + return self._prompt_token_ids_tuple + + @property + def prompt_token_ids_array(self) -> array: + """Return the prompt token ids in array type. + + Note that the array is in "I" type, and it is not compatible + with torch.long (2 bytes vs 4 bytes). So beware of the usage. + """ + return self._prompt_token_ids + + def get_prompt_len(self) -> int: + return len(self._prompt_token_ids) + + def get_last_token_id(self) -> int: + if not self._output_token_ids: + return self._prompt_token_ids[-1] + return self._output_token_ids[-1] + + def get_prompt_token_ids(self) -> Tuple[int, ...]: + return self.prompt_token_ids + + def get_prefix_token_ids( + self, + num_tokens: int, + ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: + """Get prefix tokens, and make the return value hashable""" + prompt_length = self.get_prompt_len() + if num_tokens > prompt_length: + return (self._prompt_token_ids_tuple, + tuple(self._output_token_ids[:num_tokens - prompt_length])) + else: + return (self._prompt_token_ids_tuple[:num_tokens], None) + def __repr__(self) -> str: - return (f"SequenceData(" + return (f"SequenceTokenData(" f"prompt_token_ids={self._prompt_token_ids}, " f"output_token_ids={self.output_token_ids}, " f"cumulative_logprob={self.cumulative_logprob}, " f"get_num_computed_tokens={self.get_num_computed_tokens()}") +class _EmbedBase(msgspec.Struct, omit_defaults=True): # type: ignore[call-arg] + _prompt_embeds: torch.Tensor + + +# Note the ordering here so `_prompt_embeds` is the first __init__ argument +class SequenceEmbedData(SequenceDataMixin, _EmbedBase, + omit_defaults=True): # type: ignore[call-arg] + """Data associated with a sequence. + + Args: + prompt_embeds: The embeddings of the prompt. 
+ output_token_ids: The token IDs of the output. Set to an empty list if + None. + cumulative_logprob: The cumulative log probability of the output. + + Attributes: + prompt_embeds: The embeddings of the prompt. + output_token_ids: The token IDs of the output. + cumulative_logprob: The cumulative log probability of the output. + """ + # For tagged union + type: Literal["embeds"] = "embeds" + + @property + def prompt_embeds(self) -> torch.Tensor: + return self._prompt_embeds + + def get_prompt_len(self) -> int: + return len(self._prompt_embeds) + + def __repr__(self) -> str: + return (f"SequenceEmbedData(" + f"prompt_embeds={self._prompt_embeds}, " + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"get_num_computed_tokens={self.get_num_computed_tokens()}") + + +SequenceData = Union[SequenceTokenData, SequenceEmbedData] + + class Sequence: """Stores the data, status, and block information of a sequence. @@ -414,9 +494,14 @@ def __init__( f"invalid input {inputs}; did you forget the " "encoder input prompt fields?") - self.data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids), - self.prompt_embeds) + data: SequenceData + if self.prompt_token_ids: + data = SequenceTokenData.from_seq(self.prompt_token_ids) + else: + data = SequenceEmbedData(self.prompt_embeds) + + self.data = data + self.output_logprobs: SampleLogprobs = [] self.output_text = "" From 7f8ed8c8e97b59a2abb28ad3b462330f8f78493e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 15:27:05 +0000 Subject: [PATCH 35/88] Rename `PromptInputs` to `PromptType` --- benchmarks/benchmark_latency.py | 8 +- .../dev/multimodal/multimodal_index.rst | 2 +- .../dev/offline_inference/llm_inputs.rst | 2 +- docs/source/models/vlm.rst | 2 +- tests/mq_llm_engine/test_error_handling.py | 12 +- tests/mq_llm_engine/utils.py | 2 +- vllm/__init__.py | 4 +- vllm/engine/async_llm_engine.py | 24 +- vllm/engine/llm_engine.py | 15 +- vllm/engine/multiprocessing/__init__.py | 4 +- vllm/engine/multiprocessing/client.py | 6 +- vllm/engine/multiprocessing/engine.py | 2 +- vllm/engine/protocol.py | 8 +- vllm/entrypoints/llm.py | 38 +-- vllm/inputs/__init__.py | 12 +- vllm/inputs/data.py | 92 +++--- vllm/inputs/parse.py | 34 +-- vllm/inputs/preprocess.py | 261 ++++++++++-------- 18 files changed, 291 insertions(+), 237 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index a39d1cf842f06..eadf994cacd34 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -61,7 +61,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptInputs] = [{ + dummy_prompts: List[PromptType] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] @@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(dummy_inputs, + 
llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index 241b2ccd0991e..e112b43aade5e 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -8,7 +8,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` -via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. +via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities by following :ref:`this guide `. diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst index 9adf82d43f3e0..0d47281db485e 100644 --- a/docs/source/dev/offline_inference/llm_inputs.rst +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -1,7 +1,7 @@ LLM Inputs ========== -.. autodata:: vllm.inputs.PromptInputs +.. autodata:: vllm.inputs.PromptType .. autoclass:: vllm.inputs.TextPrompt :show-inheritance: diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 08db891665044..ca5b125369c85 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. -To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`: * ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 49cfc5aa04c36..7c466c92d5293 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket): # Throws an error in first forward pass. with pytest.raises(RAISED_ERROR): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket): # Engine is errored, should get ENGINE_DEAD_ERROR. 
with pytest.raises(MQEngineDeadError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket): # Generate call should throw ENGINE_DEAD_ERROR with pytest.raises(MQEngineDeadError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -165,7 +165,7 @@ async def bad_abort_after_2s(): # with reference to the original KeyError("foo") with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( - inputs="Hello my name is", + prompt="Hello my name is", sampling_params=SamplingParams(max_tokens=2000), request_id=uuid.uuid4()): pass @@ -190,7 +190,7 @@ async def test_bad_request(tmp_socket): # Invalid request should fail, but not crash the server. with pytest.raises(ValueError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-1", lora_request=LoRARequest( @@ -199,7 +199,7 @@ async def test_bad_request(tmp_socket): pass # This request should be okay. - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-2"): pass diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index e27fd77923412..3ffa126070ca0 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -20,7 +20,7 @@ async def generate( count = 0 async for out in client.generate( request_id=request_id, - inputs="Hello my name is Robert and", + prompt="Hello my name is Robert and", sampling_params=SamplingParams(max_tokens=num_tokens, temperature=0)): diff --git a/vllm/__init__.py b/vllm/__init__.py index 0895c571d1d89..59af68fb493e5 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,7 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -19,7 +19,7 @@ "__version__", "LLM", "ModelRegistry", - "PromptInputs", + "PromptType", "TextPrompt", "TokensPrompt", "SamplingParams", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 34e7e05341f02..01f76cce82348 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -17,7 +17,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -405,7 +405,7 @@ async def stop_remote_worker_execution_loop_async(self) -> None: async def add_request_async( self, request_id: str, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: 
Optional[LoRARequest] = None, @@ -420,7 +420,7 @@ async def add_request_async( arrival_time = time.time() preprocessed_inputs = await self.input_preprocessor.preprocess_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -777,7 +777,7 @@ async def run_engine_loop(engine_ref: ReferenceType): async def add_request( self, request_id: str, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -797,7 +797,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, verbose=self.log_requests, - inputs=inputs, + inputs=prompt, params=params, arrival_time=arrival_time or time.time(), lora_request=lora_request, @@ -808,7 +808,7 @@ async def add_request( async def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -822,8 +822,7 @@ async def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -881,7 +880,7 @@ async def generate( """ async for output in await self.add_request( request_id, - inputs, + prompt, sampling_params, lora_request=lora_request, trace_headers=trace_headers, @@ -891,7 +890,7 @@ async def generate( async def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -904,8 +903,7 @@ async def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. 
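
To make the renamed parameter concrete, here is a minimal, self-contained sketch of driving `AsyncLLMEngine.generate` with a plain-string prompt. The engine construction, model name, and sampling settings below are placeholders for illustration and are not part of this patch; any `PromptType` variant (`str`, `TextPrompt`, `TokensPrompt`, `EmbedsPrompt`) may be passed as `prompt`.

    import asyncio
    import uuid

    from vllm import SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine


    async def main() -> None:
        # Placeholder engine setup; any decoder-only model works here.
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m"))

        final_output = None
        # `prompt` replaces the old `inputs` keyword; a plain string is the
        # simplest PromptType variant.
        async for output in engine.generate(
                prompt="Hello my name is",
                sampling_params=SamplingParams(max_tokens=16),
                request_id=str(uuid.uuid4())):
            final_output = output

        if final_output is not None:
            print(final_output.outputs[0].text)


    asyncio.run(main())
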
@@ -959,7 +957,7 @@ async def encode( """ async for output in await self.add_request( request_id, - inputs, + prompt, pooling_params, lora_request=lora_request, trace_headers=trace_headers, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7ec03db7c7687..a1170ea52d6af 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -28,8 +28,8 @@ from vllm.executor.executor_base import ExecutorBase from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptInputs) +from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderInputs, InputRegistry, + LLMInputs, PromptType) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -617,7 +617,7 @@ def _verify_args(self) -> None: def _add_processed_request( self, request_id: str, - processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs], + processed_inputs: Union[LLMInputs, EncoderDecoderInputs], params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], @@ -682,7 +682,7 @@ def stop_remote_worker_execution_loop(self) -> None: def add_request( self, request_id: str, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -697,8 +697,7 @@ def add_request( Args: request_id: The unique ID of the request. - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. params: Parameters for sampling or pooling. :class:`~vllm.SamplingParams` for text generation. 
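
As a hedged usage sketch (not part of this patch), the renamed `prompt` argument of `LLMEngine.add_request` accepts any `PromptType` variant; the model name below is a placeholder.

    from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt

    # Placeholder model; the point is the add_request() call signature.
    engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
    params = SamplingParams(max_tokens=16)

    # A plain string and a pre-tokenized TokensPrompt are both valid prompts.
    engine.add_request("req-0", "Hello my name is", params)
    engine.add_request("req-1", TokensPrompt(prompt_token_ids=[1, 2, 3]), params)

    while engine.has_unfinished_requests():
        for request_output in engine.step():
            if request_output.finished:
                print(request_output.request_id, request_output.outputs[0].text)
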
@@ -738,7 +737,7 @@ def add_request( arrival_time = time.time() preprocessed_inputs = self.input_preprocessor.preprocess( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -1710,7 +1709,7 @@ def _support_prompt_embeds(self) -> Tuple[bool, str]: "input embeddings, but prompt_embeds was provided.") def _validate_model_inputs(self, inputs: Union[LLMInputs, - EncoderDecoderLLMInputs]): + EncoderDecoderInputs]): if self.is_encoder_decoder_model(): prompt_ids = inputs.get("encoder_prompt_token_ids") else: diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index ba5c6e15fc821..3c35a5fcb11a9 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -2,7 +2,7 @@ from enum import Enum from typing import List, Mapping, Optional, Union -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest @@ -22,7 +22,7 @@ class MQEngineDeadError(RuntimeError): @dataclass class RPCGenerateRequest: - inputs: PromptInputs + prompt: PromptType sampling_params: SamplingParams request_id: str lora_request: Optional[LoRARequest] = None diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 2cb4de79131f1..0ce00a266df93 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -24,7 +24,7 @@ RPCStartupResponse) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -389,7 +389,7 @@ def dead_error(self) -> BaseException: async def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -421,7 +421,7 @@ async def generate( request_bytes = pickle.dumps( RPCGenerateRequest( - inputs=inputs, + prompt=prompt, sampling_params=sampling_params, request_id=request_id, lora_request=lora_request, diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 70cd6e5cb6000..f887512919a80 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -244,7 +244,7 @@ def _handle_generate_request(self, request: RPCGenerateRequest): try: self.engine.add_request( request_id=request_id, - inputs=request.inputs, + prompt=request.prompt, params=request.sampling_params, lora_request=request.lora_request, trace_headers=request.trace_headers, diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 70444faa670a2..564837344685e 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,7 @@ from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.inputs.data import PromptInputs +from vllm.inputs.data import PromptType from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -35,19 +35,19 @@ def dead_error(self) -> BaseException: def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = 
None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: - """Generates outputs for a request""" + """Generate outputs for a request""" ... def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 248b070611cd2..0456d8a0d72a5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -10,7 +10,7 @@ apply_hf_chat_template, apply_mistral_chat_template, parse_chat_messages) -from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -258,7 +258,7 @@ def generate( @overload def generate( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], + inputs: Union[PromptType, Sequence[PromptType]], /, # We may enable `inputs` keyword after removing the old API *, sampling_params: Optional[Union[SamplingParams, @@ -276,7 +276,7 @@ def generate( ) def generate( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[Union[PromptType, Sequence[PromptType]], Optional[Union[str, List[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -325,7 +325,7 @@ def generate( prompt_token_ids=prompt_token_ids, ) else: - inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + inputs = cast(Union[PromptType, Sequence[PromptType]], prompts) if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: @@ -414,17 +414,17 @@ def chat( tools=tools, ) - inputs: PromptInputs + parsed_prompt: PromptType if is_list_of(prompt, int): - inputs = TokensPrompt(prompt_token_ids=prompt) + parsed_prompt = TokensPrompt(prompt_token_ids=prompt) else: - inputs = TextPrompt(prompt=prompt) + parsed_prompt = TextPrompt(prompt=prompt) if mm_data is not None: - inputs["multi_modal_data"] = mm_data + parsed_prompt["multi_modal_data"] = mm_data return self.generate( - inputs, + parsed_prompt, sampling_params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, @@ -494,7 +494,7 @@ def encode( @overload def encode( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], + inputs: Union[PromptType, Sequence[PromptType]], /, # We may enable `inputs` keyword after removing the old API *, pooling_params: Optional[Union[PoolingParams, @@ -512,7 +512,7 @@ def encode( ) def encode( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[Union[PromptType, Sequence[PromptType]], Optional[Union[str, List[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -528,8 +528,8 @@ def encode( into a single list and pass it to this method. Args: - inputs: The inputs to the LLM. You may pass a sequence of inputs for - batch inference. See :class:`~vllm.inputs.PromptInputs` + prompts: The prompts to the LLM. You may pass a sequence of inputs + for batch inference. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. 
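
For reference, a short sketch of the user-facing `LLM` API after this rename; the model name is a placeholder, and the three prompts illustrate the `PromptType` schemas accepted by `generate`.

    from vllm import LLM, SamplingParams, TextPrompt, TokensPrompt

    llm = LLM(model="facebook/opt-125m")  # placeholder model

    prompts = [
        "Hello my name is",                             # plain string
        TextPrompt(prompt="The capital of France is"),  # TextPrompt dict
        TokensPrompt(prompt_token_ids=[1, 2, 3]),       # pre-tokenized prompt
    ]

    for request_output in llm.generate(prompts, SamplingParams(max_tokens=16)):
        print(request_output.outputs[0].text)
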
@@ -558,7 +558,7 @@ def encode( prompt_token_ids=prompt_token_ids, ) else: - inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + inputs = cast(Union[PromptType, Sequence[PromptType]], prompts) if pooling_params is None: # Use default pooling params. @@ -609,9 +609,9 @@ def _convert_v1_inputs( raise ValueError("Either prompts or prompt_token_ids must be " "provided.") - inputs: List[PromptInputs] = [] + inputs: List[PromptType] = [] for i in range(num_requests): - item: PromptInputs + item: PromptType if prompts is not None: item = TextPrompt(prompt=prompts[i]) @@ -626,7 +626,7 @@ def _convert_v1_inputs( def _validate_and_add_requests( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], + inputs: Union[PromptType, Sequence[PromptType]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], @@ -665,7 +665,7 @@ def _validate_and_add_requests( def _add_request( self, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -673,7 +673,7 @@ def _add_request( request_id = str(next(self.request_counter)) self.llm_engine.add_request( request_id, - inputs, + prompt, params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index d83878019138d..4315d5c2b81e7 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,6 +1,6 @@ -from .data import (EmbedsPrompt, EncoderDecoderLLMInputs, - ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs, - SingletonPromptInputs, TextPrompt, TokensPrompt, +from .data import (EmbedsPrompt, EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, LLMInputs, PromptType, + SingletonPrompt, TextPrompt, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry @@ -18,11 +18,11 @@ "TextPrompt", "TokensPrompt", "EmbedsPrompt", - "PromptInputs", - "SingletonPromptInputs", + "PromptType", + "SingletonPrompt", "ExplicitEncoderDecoderPrompt", "LLMInputs", - "EncoderDecoderLLMInputs", + "EncoderDecoderInputs", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 3e233ad12a71e..589c9a83e9500 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,5 +1,5 @@ -from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple, - Union) +from typing import (TYPE_CHECKING, Generic, Iterable, List, Literal, Optional, + Tuple, Union) import torch from typing_extensions import NotRequired, TypedDict, TypeVar @@ -47,7 +47,7 @@ class EmbedsPrompt(TypedDict): """ -SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] +SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] """ Set of possible schemas for a single LLM input: @@ -61,7 +61,7 @@ class EmbedsPrompt(TypedDict): the user desires to express both the encoder & decoder prompts explicitly, i.e. 
:class:`ExplicitEncoderDecoderPrompt` -A prompt of type :class:`SingletonPromptInputs` may be employed +A prompt of type :class:`SingletonPrompt` may be employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or @@ -70,12 +70,12 @@ class EmbedsPrompt(TypedDict): """ _T1_co = TypeVar("_T1_co", - bound=SingletonPromptInputs, - default=SingletonPromptInputs, + bound=SingletonPrompt, + default=SingletonPrompt, covariant=True) _T2_co = TypeVar("_T2_co", - bound=SingletonPromptInputs, - default=SingletonPromptInputs, + bound=SingletonPrompt, + default=SingletonPrompt, covariant=True) @@ -87,7 +87,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): The encoder and decoder prompts, respectively, may formatted according to any of the - :class:`SingletonPromptInputs` schemas, and are not + :class:`SingletonPrompt` schemas, and are not required to have the same schema. Only the encoder prompt may have multi-modal data. @@ -96,7 +96,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): be used as an input to a decoder-only model, and that the `encoder_prompt` and `decoder_prompt` fields of this data structure themselves must be - :class:`SingletonPromptInputs` instances. + :class:`SingletonPrompt` instances. """ encoder_prompt: _T1_co @@ -104,7 +104,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): decoder_prompt: Optional[_T2_co] -PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] +PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] """ Set of possible schemas for an LLM input, including both decoder-only and encoder/decoder input types: @@ -116,56 +116,76 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): """ -class LLMInputs(TypedDict): - """ - The inputs in :class:`~vllm.LLMEngine` before they are - passed to the model executor. +class TokenInputs(TypedDict): + """Represents token-based inputs.""" - This specifies the data required for decoder-only models. - """ - prompt_token_ids: Optional[List[int]] + type: Literal["token"] + """The type of inputs.""" + + prompt_token_ids: List[int] """The token IDs of the prompt.""" - prompt: NotRequired[Optional[str]] + prompt: NotRequired[str] """ The original prompt text corresponding to the token IDs, if available. """ - prompt_embeds: NotRequired[Optional[torch.Tensor]] + multi_modal_data: NotRequired["MultiModalDataDict"] """ - The embeddings of the prompt, if available. + Optional multi-modal data to pass to the model, + if the model supports it. """ - multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] + +class EmbedInputs(TypedDict): + """Represents embedding-based inputs.""" + + type: Literal["embed"] + """The type of inputs.""" + + prompt_embeds: torch.Tensor + """The embeddings of the prompt.""" + + multi_modal_data: NotRequired["MultiModalDataDict"] """ Optional multi-modal data to pass to the model, if the model supports it. """ -class EncoderDecoderLLMInputs(LLMInputs): +LLMInputs = Union[TokenInputs, EmbedInputs] +""" +The inputs in :class:`~vllm.LLMEngine` before they are +passed to the model executor. + +This specifies the data required for decoder-only models. 
+""" + + +class EmptyInputs(TypedDict): + """Represents empty inputs.""" + + type: Literal["empty"] + """The type of inputs.""" + + +class EncoderDecoderInputs(TypedDict): """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. This specifies the required data for encoder-decoder models. """ - encoder_prompt_token_ids: Optional[List[int]] - """The token IDs of the encoder prompt.""" - encoder_prompt: NotRequired[Optional[str]] - """ - The original encoder prompt text corresponding to the token IDs, if - available. - """ + encoder: TokenInputs + """The inputs for the encoder portion.""" + + decoder: Union[EmptyInputs, TokenInputs] + """The inputs for the decoder portion.""" -_T1 = TypeVar("_T1", - bound=SingletonPromptInputs, - default=SingletonPromptInputs) -_T2 = TypeVar("_T2", - bound=SingletonPromptInputs, - default=SingletonPromptInputs) +_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) +_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) def build_explicit_enc_dec_prompt( diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index d6aadf4d49f44..e725e88c226a9 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -4,9 +4,9 @@ from vllm.utils import is_list_of -from .data import (EmbedsPrompt, EncoderDecoderLLMInputs, - ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs, - SingletonPromptInputs, TextPrompt, TokensPrompt) +from .data import (EmbedsPrompt, EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, LLMInputs, PromptType, + SingletonPrompt, TextPrompt, TokensPrompt) class ParsedText(TypedDict): @@ -86,30 +86,30 @@ class ParsedEmbedsPrompt(TypedDict): def parse_singleton_prompt( - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, ParsedEmbedsPrompt]: - if isinstance(inputs, str): - return ParsedStrPrompt(type="str", content=inputs) - elif isinstance(inputs, dict): - if 'prompt_embeds' in inputs: + if isinstance(prompt, str): + return ParsedStrPrompt(type="str", content=prompt) + elif isinstance(prompt, dict): + if 'prompt_embeds' in prompt: return ParsedEmbedsPrompt(type="embeds", - content=inputs) # type: ignore - elif "prompt_token_ids" in inputs: + content=prompt) # type: ignore + elif "prompt_token_ids" in prompt: return ParsedTokensPrompt(type="tokens", - content=inputs) # type: ignore - elif "prompt" in inputs: - return ParsedTextPrompt(type="text", content=inputs) + content=prompt) # type: ignore + elif "prompt" in prompt: + return ParsedTextPrompt(type="text", content=prompt) raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") def is_explicit_encoder_decoder_prompt( - inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]: - return isinstance(inputs, dict) and "encoder_prompt" in inputs + prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(prompt, dict) and "encoder_prompt" in prompt def is_valid_encoder_decoder_llm_inputs( - inputs: Union[LLMInputs, EncoderDecoderLLMInputs], -) -> TypeIs[EncoderDecoderLLMInputs]: + inputs: Union[LLMInputs, EncoderDecoderInputs], +) -> TypeIs[EncoderDecoderInputs]: return "encoder_prompt_token_ids" in inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 5117ca5a66b4e..33b0582210099 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,5 +1,4 @@ import asyncio -from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional, Union import torch @@ -11,8 +10,8 
@@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, - SingletonPromptInputs) +from .data import (TokenInputs, EmbedInputs, EncoderDecoderInputs, EmptyInputs, + LLMInputs, PromptType, SingletonPrompt) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt if TYPE_CHECKING: @@ -21,22 +20,6 @@ logger = init_logger(__name__) -@dataclass(frozen=True) -class PromptComponents: - prompt: Optional[str] - prompt_token_ids: List[int] - prompt_embeds: Optional[torch.Tensor] = None - multi_modal_data: Optional["MultiModalDataDict"] = None - - -@dataclass(frozen=True) -class DecoderPromptComponents: - prompt: Optional[str] - prompt_token_ids: Optional[List[int]] - prompt_embeds: Optional[torch.Tensor] = None - multi_modal_data: Optional["MultiModalDataDict"] = None - - class InputPreprocessor: def __init__( @@ -217,12 +200,42 @@ async def _tokenize_prompt_async( prompt=prompt, lora_request=lora_request) - def _extract_prompt_components( + def _token_inputs( + self, + prompt_token_ids: List[int], + prompt: Optional[str] = None, + multi_modal_data: Optional["MultiModalDataDict"] = None, + ) -> TokenInputs: + inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) + + if prompt is not None: + inputs["prompt"] = prompt + if multi_modal_data is not None: + inputs["multi_modal_data"] = multi_modal_data + + return inputs + + def _embed_inputs( self, - inputs: SingletonPromptInputs, + prompt_embeds: torch.Tensor, + multi_modal_data: Optional["MultiModalDataDict"] = None, + ) -> EmbedInputs: + inputs = EmbedInputs(type="embed", prompt_embeds=prompt_embeds) + + if multi_modal_data is not None: + inputs["multi_modal_data"] = multi_modal_data + + return inputs + + def _empty_inputs(self) -> EmptyInputs: + return EmptyInputs(type="empty") + + def _prompt_to_llm_inputs( + self, + inputs: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: + ) -> LLMInputs: ''' Extract the components of any single encoder or decoder input prompt. 
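
The helpers above return the new tagged-union form of `LLMInputs` (`TokenInputs` or `EmbedInputs`). Below is a small sketch of how downstream code can narrow on the `type` tag, mirroring the exhaustive `assert_never` pattern used later in this file; the `describe_inputs` helper is hypothetical and only for illustration.

    import torch
    from typing_extensions import assert_never

    from vllm.inputs import LLMInputs


    def describe_inputs(inputs: LLMInputs) -> str:
        # Branch on the tagged union so type checkers can prove exhaustiveness.
        if inputs["type"] == "token":
            return f"{len(inputs['prompt_token_ids'])} prompt tokens"
        elif inputs["type"] == "embed":
            shape = tuple(inputs["prompt_embeds"].shape)
            return f"prompt embeddings with shape {shape}"
        else:
            assert_never(inputs)


    print(describe_inputs({"type": "token", "prompt_token_ids": [1, 2, 3]}))
    print(describe_inputs({"type": "embed", "prompt_embeds": torch.zeros(4, 8)}))
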
@@ -249,16 +262,19 @@ def _extract_prompt_components( lora_request=lora_request, ) - return PromptComponents(prompt=prompt, - prompt_token_ids=prompt_token_ids) + return self._token_inputs( + prompt=prompt, + prompt_token_ids=prompt_token_ids, + ) if parsed["type"] == "tokens": prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") - return PromptComponents(prompt=None, - prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data) + return self._token_inputs( + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + ) if parsed["type"] == "text": prompt = parsed["content"]["prompt"] @@ -269,27 +285,29 @@ def _extract_prompt_components( ) multi_modal_data = parsed["content"].get("multi_modal_data") - return PromptComponents(prompt=prompt, - prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data) + return self._token_inputs( + prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + ) if parsed["type"] == "embeds": prompt_embeds = parsed["content"]["prompt_embeds"] multi_modal_data = parsed["content"].get("multi_modal_data") - return PromptComponents(prompt=None, - prompt_token_ids=[], - multi_modal_data=multi_modal_data, - prompt_embeds=prompt_embeds) + return self._embed_inputs( + prompt_embeds=prompt_embeds, + multi_modal_data=multi_modal_data, + ) assert_never(parsed) - async def _extract_prompt_components_async( + async def _prompt_to_llm_inputs_async( self, - inputs: SingletonPromptInputs, + inputs: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: + ) -> LLMInputs: """Async version of :meth:`_extract_prompt_components`.""" parsed = parse_singleton_prompt(inputs) @@ -301,16 +319,19 @@ async def _extract_prompt_components_async( lora_request=lora_request, ) - return PromptComponents(prompt=prompt, - prompt_token_ids=prompt_token_ids) + return self._token_inputs( + prompt=prompt, + prompt_token_ids=prompt_token_ids, + ) if parsed["type"] == "tokens": prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") - return PromptComponents(prompt=None, - prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data) + return self._token_inputs( + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + ) if parsed["type"] == "text": prompt = parsed["content"]["prompt"] @@ -321,46 +342,63 @@ async def _extract_prompt_components_async( ) multi_modal_data = parsed["content"].get("multi_modal_data") - return PromptComponents(prompt=prompt, - prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data) + return self._token_inputs( + prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + ) if parsed["type"] == "embeds": prompt_embeds = parsed["content"]["prompt_embeds"] multi_modal_data = parsed["content"].get("multi_modal_data") - return PromptComponents(prompt=None, - prompt_token_ids=[], - multi_modal_data=multi_modal_data, - prompt_embeds=prompt_embeds) + return self._embed_inputs( + prompt_embeds=prompt_embeds, + multi_modal_data=multi_modal_data, + ) assert_never(parsed) def _build_enc_dec_llm_inputs( self, - encoder_comps: PromptComponents, - decoder_comps: Union[PromptComponents, DecoderPromptComponents], - ) -> EncoderDecoderLLMInputs: - if (encoder_comps.multi_modal_data is not None - or decoder_comps.multi_modal_data is not None): - raise ValueError("Multi-modal 
encoder-decoder models are " - "not supported yet") - - decoder_prompt_ids = self._prepare_decoder_input_ids_for_generation( - decoder_comps.prompt_token_ids) - - return EncoderDecoderLLMInputs( - prompt_token_ids=decoder_prompt_ids, - prompt=decoder_comps.prompt, - encoder_prompt_token_ids=encoder_comps.prompt_token_ids, - encoder_prompt=encoder_comps.prompt, + encoder_inputs: LLMInputs, + decoder_inputs: Union[EmptyInputs, LLMInputs], + ) -> EncoderDecoderInputs: + if encoder_inputs["type"] == "token": + if "multi_modal_data" in encoder_inputs: + raise ValueError("Multi-modal encoder-decoder models are " + "not supported yet") + elif encoder_inputs["type"] == "embed": + raise NotImplementedError + else: + assert_never(encoder_inputs) + + if decoder_inputs["type"] == "token": + if "prompt_token_ids" in decoder_inputs: + decoder_inputs["prompt_token_ids"] = ( + self._prepare_decoder_input_ids_for_generation( + decoder_inputs["prompt_token_ids"])) + + if "multi_modal_data" in decoder_inputs: + raise ValueError("Multi-modal encoder-decoder models are " + "not supported yet") + elif decoder_inputs["type"] == "embed": + raise NotImplementedError + elif decoder_inputs["type"] == "empty": + pass + else: + assert_never(encoder_inputs) + + return EncoderDecoderInputs( + encoder=encoder_inputs, + decoder=decoder_inputs, ) def _process_encoder_decoder_prompt( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, - ) -> EncoderDecoderLLMInputs: + ) -> EncoderDecoderInputs: ''' For encoder/decoder models only: Process an input prompt into an @@ -394,90 +432,89 @@ def _process_encoder_decoder_prompt( * :class:`EncoderDecoderLLMInputs` instance ''' - encoder_comps: PromptComponents - decoder_comps: Union[PromptComponents, DecoderPromptComponents] + encoder_inputs: LLMInputs + decoder_inputs: Union[EmptyInputs, LLMInputs] - if is_explicit_encoder_decoder_prompt(inputs): - encoder_comps = self._extract_prompt_components( - inputs["encoder_prompt"], + if is_explicit_encoder_decoder_prompt(prompt): + encoder_inputs = self._prompt_to_llm_inputs( + prompt["encoder_prompt"], request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: - decoder_comps = DecoderPromptComponents(prompt=None, - prompt_token_ids=None) + if (decoder_input := prompt["decoder_prompt"]) is None: + decoder_inputs = self._empty_inputs() else: - decoder_comps = self._extract_prompt_components( + decoder_inputs = self._prompt_to_llm_inputs( decoder_input, request_id=request_id, ) else: - encoder_comps = self._extract_prompt_components( - inputs, + encoder_inputs = self._prompt_to_llm_inputs( + prompt, request_id=request_id, ) - decoder_comps = DecoderPromptComponents(prompt=None, - prompt_token_ids=None) + decoder_inputs = self._empty_inputs() - return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) async def _process_encoder_decoder_prompt_async( self, - inputs: PromptInputs, + inputs: PromptType, request_id: str, - ) -> EncoderDecoderLLMInputs: + ) -> EncoderDecoderInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" - encoder_comps: PromptComponents - decoder_comps: Union[PromptComponents, DecoderPromptComponents] + encoder_inputs: LLMInputs + decoder_inputs: Union[EmptyInputs, LLMInputs] if is_explicit_encoder_decoder_prompt(inputs): - encoder_task = self._extract_prompt_components_async( + encoder_task = self._prompt_to_llm_inputs_async( inputs["encoder_prompt"], request_id=request_id, ) if 
(decoder_input := inputs["decoder_prompt"]) is None: - encoder_comps = await encoder_task - decoder_comps = DecoderPromptComponents(prompt=None, - prompt_token_ids=None) + encoder_inputs = await encoder_task + decoder_inputs = self._empty_inputs() else: - decoder_task = self._extract_prompt_components_async( + decoder_task = self._prompt_to_llm_inputs_async( decoder_input, request_id=request_id, ) - encoder_comps, decoder_comps = await asyncio.gather( + encoder_inputs, decoder_inputs = await asyncio.gather( encoder_task, decoder_task) else: - encoder_comps = await self._extract_prompt_components_async( + encoder_inputs = await self._prompt_to_llm_inputs_async( inputs, request_id=request_id, ) - decoder_comps = DecoderPromptComponents(prompt=None, - prompt_token_ids=None) + decoder_inputs = self._empty_inputs() - return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) def _build_decoder_only_llm_inputs( self, - prompt_comps: PromptComponents, + prompt_inputs: LLMInputs, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> LLMInputs: - prompt_token_ids = self._apply_prompt_adapter( - prompt_comps.prompt_token_ids, - prompt_adapter_request=prompt_adapter_request, - ) + if prompt_inputs["type"] == "token": + if "prompt_token_ids" in prompt_inputs: + prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( + prompt_inputs["prompt_token_ids"], + prompt_adapter_request=prompt_adapter_request, + ) + elif prompt_inputs["type"] == "embed": + pass + else: + assert_never(prompt_inputs) - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=prompt_comps.prompt, - prompt_embeds=prompt_comps.prompt_embeds, - multi_modal_data=prompt_comps.multi_modal_data) + return prompt_inputs def _process_decoder_only_prompt( self, - inputs: SingletonPromptInputs, + inputs: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -498,7 +535,7 @@ def _process_decoder_only_prompt( * :class:`LLMInputs` instance ''' - prompt_comps = self._extract_prompt_components( + prompt_comps = self._prompt_to_llm_inputs( inputs, request_id=request_id, lora_request=lora_request, @@ -511,13 +548,13 @@ def _process_decoder_only_prompt( async def _process_decoder_only_prompt_async( self, - inputs: SingletonPromptInputs, + inputs: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" - prompt_comps = await self._extract_prompt_components_async( + prompt_comps = await self._prompt_to_llm_inputs_async( inputs, request_id=request_id, lora_request=lora_request, @@ -530,11 +567,11 @@ async def _process_decoder_only_prompt_async( def preprocess( self, - inputs: PromptInputs, + inputs: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + ) -> Union[LLMInputs, EncoderDecoderInputs]: """Preprocess the input prompt.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of @@ -558,11 +595,11 @@ def preprocess( async def preprocess_async( self, - inputs: PromptInputs, + inputs: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> 
Union[LLMInputs, EncoderDecoderLLMInputs]: + ) -> Union[LLMInputs, EncoderDecoderInputs]: """Async version of :meth:`preprocess`.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of From 89b5753f0b9c229af389657a6703997c60b5519c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 15:54:18 +0000 Subject: [PATCH 36/88] Fix type errors --- vllm/engine/llm_engine.py | 13 ++- vllm/inputs/parse.py | 2 +- vllm/sequence.py | 181 +++++++++++++++++++++----------------- 3 files changed, 106 insertions(+), 90 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a1170ea52d6af..876b7aa85781a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -30,6 +30,7 @@ from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderInputs, InputRegistry, LLMInputs, PromptType) +from vllm.inputs.parse import is_valid_encoder_decoder_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -635,14 +636,10 @@ def _add_processed_request( lora_request, prompt_adapter_request) encoder_seq = None - if 'encoder_prompt_token_ids' in processed_inputs: - encoder_seq = Sequence(seq_id, - processed_inputs, - block_size, - eos_token_id, - lora_request, - prompt_adapter_request, - from_decoder_prompt=False) + if is_valid_encoder_decoder_inputs(processed_inputs): + encoder_seq = Sequence(seq_id, processed_inputs, block_size, + eos_token_id, lora_request, + prompt_adapter_request) # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index e725e88c226a9..a5211bf04cae4 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -109,7 +109,7 @@ def is_explicit_encoder_decoder_prompt( return isinstance(prompt, dict) and "encoder_prompt" in prompt -def is_valid_encoder_decoder_llm_inputs( +def is_valid_encoder_decoder_inputs( inputs: Union[LLMInputs, EncoderDecoderInputs], ) -> TypeIs[EncoderDecoderInputs]: return "encoder_prompt_token_ids" in inputs diff --git a/vllm/sequence.py b/vllm/sequence.py index ee53655fac0c5..4fca8b7071769 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,6 +5,7 @@ from array import array from collections import defaultdict from dataclasses import dataclass +from functools import cached_property from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, Optional) from typing import Sequence as GenericSequence @@ -13,7 +14,7 @@ import msgspec import torch -from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs +from vllm.inputs.parse import is_valid_encoder_decoder_inputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -21,7 +22,7 @@ from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if TYPE_CHECKING: - from vllm.inputs import LLMInputs + from vllm.inputs import EncoderDecoderInputs, LLMInputs from vllm.multimodal.base import MultiModalDataDict VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -166,6 +167,7 @@ class SequenceDataMixin(msgspec.Struct, def __post_init__(self) -> None: assert self._output_token_ids.typecode == "l" + self._update_cached_all_tokens() def _update_cached_all_tokens(self) -> None: @@ -224,6 +226,14 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return 
self._cached_all_token_ids + @abstractmethod + def get_prefix_token_ids( + self, + num_tokens: int, + ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: + """Get prefix tokens, and make the return value hashable""" + raise NotImplementedError + def get_num_computed_tokens(self) -> int: """Return the number of prefill tokens that are already computed.""" return self._num_computed_tokens @@ -253,6 +263,14 @@ def get_num_uncomputed_tokens(self) -> int: # prefill for both prompt and output. return self.get_len() - self.get_num_computed_tokens() + @abstractmethod + def get_last_token_id(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_prompt_token_ids(self) -> Tuple[int, ...]: + raise NotImplementedError + def get_output_token_ids(self) -> Tuple[int, ...]: return self.output_token_ids @@ -298,13 +316,13 @@ class SequenceTokenData(SequenceDataMixin, _TokenBase, cumulative_logprob: The cumulative log probability of the output. """ + # For tagged union + type: Literal["tokens"] = "tokens" + ### The below fields should not be passed as an argument ### _prompt_token_ids_tuple: Tuple[int, ...] = msgspec.field(default_factory=tuple) - # For tagged union - type: Literal["tokens"] = "tokens" - @staticmethod def from_seq( prompt_token_ids: GenericSequence[int], @@ -355,6 +373,7 @@ def get_prompt_len(self) -> int: def get_last_token_id(self) -> int: if not self._output_token_ids: return self._prompt_token_ids[-1] + return self._output_token_ids[-1] def get_prompt_token_ids(self) -> Tuple[int, ...]: @@ -364,13 +383,12 @@ def get_prefix_token_ids( self, num_tokens: int, ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: - """Get prefix tokens, and make the return value hashable""" prompt_length = self.get_prompt_len() if num_tokens > prompt_length: return (self._prompt_token_ids_tuple, tuple(self._output_token_ids[:num_tokens - prompt_length])) - else: - return (self._prompt_token_ids_tuple[:num_tokens], None) + + return (self._prompt_token_ids_tuple[:num_tokens], None) def __repr__(self) -> str: return (f"SequenceTokenData(" @@ -400,9 +418,22 @@ class SequenceEmbedData(SequenceDataMixin, _EmbedBase, output_token_ids: The token IDs of the output. cumulative_logprob: The cumulative log probability of the output. """ + # For tagged union type: Literal["embeds"] = "embeds" + ### The below fields should not be passed as an argument ### + _dummy_token_ids_tuple: Tuple[int, + ...] 
= msgspec.field(default_factory=tuple) + + def __post_init__(self) -> None: + assert self._output_token_ids.typecode == "l" + + # Dummy value + self._dummy_token_ids_tuple = tuple([0] * self._prompt_embeds) + + self._update_cached_all_tokens() + @property def prompt_embeds(self) -> torch.Tensor: return self._prompt_embeds @@ -410,6 +441,26 @@ def prompt_embeds(self) -> torch.Tensor: def get_prompt_len(self) -> int: return len(self._prompt_embeds) + def get_last_token_id(self) -> int: + if not self._output_token_ids: + return self._dummy_token_ids_tuple[-1] + + return self._output_token_ids[-1] + + def get_prompt_token_ids(self) -> Tuple[int, ...]: + return self._dummy_token_ids_tuple + + def get_prefix_token_ids( + self, + num_tokens: int, + ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: + prompt_length = self.get_prompt_len() + if num_tokens > prompt_length: + return (self._dummy_token_ids_tuple, + tuple(self._output_token_ids[:num_tokens - prompt_length])) + + return (self._dummy_token_ids_tuple[:num_tokens], None) + def __repr__(self) -> str: return (f"SequenceEmbedData(" f"prompt_embeds={self._prompt_embeds}, " @@ -424,15 +475,10 @@ def __repr__(self) -> str: class Sequence: """Stores the data, status, and block information of a sequence. - The sequence is constructed from the LLMInputs instance passed + The sequence is constructed from the LLMInputs (for decoder-only) or + EncoderDecoderInputs (for encoder-decoder) instance passed in through the `inputs` constructor argument. - For encoder/decoder models, LLMInputs encapsulates both a - decoder and encoder prompt, creating an ambiguity about which - prompt to construct the sequence from. The `from_decoder_prompt` - constructor argument signals whether to construct the Sequence - from the LLMInputs decoder prompt, or encoder prompt. - Args: seq_id: The ID of the sequence. inputs: The inputs of the sequence. @@ -441,21 +487,17 @@ class Sequence: eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. lora_request: LoRA request. prompt_adapter_request: Prompt Adapter request. - from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt - (True) or encoder prompt (False.) Must be True - for decoder-only model. """ def __init__( self, seq_id: int, - inputs: "LLMInputs", + inputs: Union["LLMInputs", "EncoderDecoderInputs"], block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - from_decoder_prompt: bool = True, ) -> None: self.seq_id = seq_id self.inputs = inputs @@ -463,41 +505,12 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - self.from_decoder_prompt = from_decoder_prompt - self._prompt: Optional[str] = None - self._prompt_token_ids: Optional[List[int]] = None - - # For decoder-only models, a Sequence is constructed - # from an LLMInputs instance (the `inputs` arg.) - # - # For encoder/decoder models the same `inputs` - # instance could be utilized to construct either an - # encoder sequence or a decoder sequence, because - # `LLMInputs` has both decoder- and encoder-oriented - # member variables (i.e. it encapsulates both an encoder - # and a decoder prompt.) The decision of which type of sequence - # to generate is determined by the `from_decoder_prompt` argument. 
- # - # When constructing a encoder sequence - # (`from_decoder_prompt` False) it matters that - # the `LLMInputs` instance stored in `inputs` is valid - # in the sense that its encoder-related member variables are - # populated; below, an exception is raised if this is - # not the case. - # - # When constructing a decoder sequence (`from_decoder_prompt` True) - # it does not matter whether `inputs` has its encoder-related - # member variables populated. - if not (from_decoder_prompt - or is_valid_encoder_decoder_llm_inputs(inputs)): - raise ValueError("Cannot extract encoder input prompt from " - f"invalid input {inputs}; did you forget the " - "encoder input prompt fields?") data: SequenceData if self.prompt_token_ids: data = SequenceTokenData.from_seq(self.prompt_token_ids) else: + assert self.prompt_embeds is not None data = SequenceEmbedData(self.prompt_embeds) self.data = data @@ -522,45 +535,51 @@ def __init__( def n_blocks(self) -> int: return (self.get_len() + self.block_size - 1) // self.block_size - @property + @cached_property def prompt(self) -> Optional[str]: - if self._prompt is not None: - # Reuse precomputed prompt string - return self._prompt - - # Select decoder or encoder input prompt str, - # as appropriate - prompt_key: str = ("prompt" - if self.from_decoder_prompt else "encoder_prompt") + # Select decoder or encoder input prompt str, as appropriate + inputs = self.inputs + if is_valid_encoder_decoder_inputs(inputs): + prompt = inputs["encoder"].get("prompt") + else: + prompt = cast(Optional[str], inputs.get("prompt")) - # Cache prompt - self._prompt = cast(Optional[str], self.inputs.get(prompt_key)) - return self._prompt + return prompt - @property + @cached_property def prompt_token_ids(self) -> List[int]: - if self._prompt_token_ids is not None: - # Reuse precomputed prompt token ids - return self._prompt_token_ids - - # Select decoder or encoder input prompt - # token ids, as appropriate - prompt_token_ids_key: str = ("prompt_token_ids" - if self.from_decoder_prompt else - "encoder_prompt_token_ids") - - # Cache computed prompt token ids - self._prompt_token_ids = cast(List[int], - self.inputs.get(prompt_token_ids_key)) - return self._prompt_token_ids + # Select decoder or encoder input prompt token ids, as appropriate + inputs = self.inputs + if is_valid_encoder_decoder_inputs(inputs): + prompt_token_ids = inputs["encoder"].get("prompt_token_ids") + else: + prompt_token_ids = cast(Optional[List[int]], + inputs.get("prompt_token_ids")) - @property + return prompt_token_ids or [] + + @cached_property def prompt_embeds(self) -> Optional[torch.Tensor]: - return self.inputs.get("prompt_embeds") + # Select decoder or encoder input prompt embeds, as appropriate + inputs = self.inputs + if is_valid_encoder_decoder_inputs(inputs): + prompt_embeds = inputs["encoder"].get("prompt_embeds") + else: + prompt_embeds = cast(Optional[torch.Tensor], + inputs.get("prompt_embeds")) - @property + return prompt_embeds + + @cached_property def multi_modal_data(self) -> "MultiModalDataDict": - return self.inputs.get("multi_modal_data") or {} + inputs = self.inputs + if is_valid_encoder_decoder_inputs(inputs): + multi_modal_data = inputs["encoder"].get("multi_modal_data") + else: + multi_modal_data = cast(Optional["MultiModalDataDict"], + inputs.get("multi_modal_data")) + + return multi_modal_data or {} @property def lora_int_id(self) -> int: From 8e91af3ca8178adce749764e5016c25d194ae557 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 15:54:24 +0000 Subject: [PATCH 
37/88] format --- tests/samplers/test_sampler.py | 4 ++-- tests/spec_decode/utils.py | 4 ++-- tests/test_logits_processor.py | 4 ++-- tests/test_sequence.py | 4 ++-- tests/worker/test_encoder_decoder_model_runner.py | 4 ++-- tests/worker/test_model_runner.py | 4 ++-- vllm/inputs/preprocess.py | 4 ++-- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index eed47b57b5f68..fe82485b597f4 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -11,8 +11,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import (SamplingParams, SequenceTokenData, - SequenceGroupMetadata) +from vllm.sequence import (SamplingParams, SequenceGroupMetadata, + SequenceTokenData) from vllm.utils import Counter, is_pin_memory_available diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index e18517dc2e4bd..03be62ffa89d4 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -11,8 +11,8 @@ from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceTokenData, SequenceGroupMetadata, - SequenceOutput) + SequenceGroupMetadata, SequenceOutput, + SequenceTokenData) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 3cbd5e7bacb0e..50139c971a6fb 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -8,8 +8,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import (SamplingParams, SequenceTokenData, - SequenceGroupMetadata) +from vllm.sequence import (SamplingParams, SequenceGroupMetadata, + SequenceTokenData) from vllm.utils import is_pin_memory_available diff --git a/tests/test_sequence.py b/tests/test_sequence.py index b5515417da50a..9894b730fd091 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,8 +1,8 @@ import pytest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (CompletionSequenceGroupOutput, SequenceTokenData, - SequenceOutput) +from vllm.sequence import (CompletionSequenceGroupOutput, SequenceOutput, + SequenceTokenData) from .core.utils import create_dummy_prompt diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 2dc858a66b7ca..36ca90ad6400b 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -5,8 +5,8 @@ import torch from vllm.engine.arg_utils import EngineArgs -from vllm.sequence import (SamplingParams, SequenceTokenData, - SequenceGroupMetadata) +from vllm.sequence import (SamplingParams, SequenceGroupMetadata, + SequenceTokenData) from vllm.utils import is_cpu, make_tensor_with_pad from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import _get_graph_batch_size diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 
7478cba5f55df..76e8422889510 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -8,8 +8,8 @@ init_distributed_environment) from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import (SamplingParams, SequenceTokenData, - SequenceGroupMetadata) +from vllm.sequence import (SamplingParams, SequenceGroupMetadata, + SequenceTokenData) from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 33b0582210099..f7d94523a4f44 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -10,8 +10,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from .data import (TokenInputs, EmbedInputs, EncoderDecoderInputs, EmptyInputs, - LLMInputs, PromptType, SingletonPrompt) +from .data import (EmbedInputs, EmptyInputs, EncoderDecoderInputs, LLMInputs, + PromptType, SingletonPrompt, TokenInputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt if TYPE_CHECKING: From 31e2a1b54a5a0bb678f15715bda071fd0ba5cc42 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 16:26:48 +0000 Subject: [PATCH 38/88] Rename `LLMInputs` to `DecoderOnlyInputs` and fix input processing for VLMs --- .../input_processing/model_inputs_index.rst | 2 +- .../decoder_only/vision_language/test_qwen.py | 10 +- tests/samplers/test_sampler.py | 10 +- tests/spec_decode/utils.py | 4 +- tests/test_logits_processor.py | 2 +- tests/test_sequence.py | 2 +- .../test_encoder_decoder_model_runner.py | 12 +-- tests/worker/test_model_runner.py | 16 +-- vllm/engine/llm_engine.py | 8 +- vllm/inputs/__init__.py | 17 ++- vllm/inputs/data.py | 36 ++++++- vllm/inputs/parse.py | 8 +- vllm/inputs/preprocess.py | 101 ++++++------------ vllm/inputs/registry.py | 13 +-- vllm/model_executor/models/blip.py | 22 ++-- vllm/model_executor/models/blip2.py | 27 ++--- vllm/model_executor/models/chameleon.py | 27 ++--- vllm/model_executor/models/clip.py | 24 ++--- vllm/model_executor/models/fuyu.py | 15 +-- vllm/model_executor/models/internvl.py | 12 ++- vllm/model_executor/models/llava.py | 5 +- vllm/model_executor/models/llava_next.py | 5 +- .../model_executor/models/llava_next_video.py | 11 +- vllm/model_executor/models/minicpmv.py | 17 +-- vllm/model_executor/models/paligemma.py | 12 ++- vllm/model_executor/models/phi3v.py | 11 +- vllm/model_executor/models/pixtral.py | 19 ++-- vllm/model_executor/models/qwen.py | 25 ++--- vllm/model_executor/models/qwen2_vl.py | 35 +++--- vllm/model_executor/models/siglip.py | 18 ++-- vllm/model_executor/models/ultravox.py | 17 +-- vllm/sequence.py | 24 +++-- 32 files changed, 299 insertions(+), 268 deletions(-) diff --git a/docs/source/dev/input_processing/model_inputs_index.rst b/docs/source/dev/input_processing/model_inputs_index.rst index 5d895837590ba..f0ec1fea15ddb 100644 --- a/docs/source/dev/input_processing/model_inputs_index.rst +++ b/docs/source/dev/input_processing/model_inputs_index.rst @@ -25,7 +25,7 @@ Module Contents LLM Engine Inputs ----------------- -.. autoclass:: vllm.inputs.LLMInputs +.. 
autoclass:: vllm.inputs.DecoderOnlyInputs :members: :show-inheritance: diff --git a/tests/models/decoder_only/vision_language/test_qwen.py b/tests/models/decoder_only/vision_language/test_qwen.py index e4f79092b7606..37c2e642d0fb6 100644 --- a/tests/models/decoder_only/vision_language/test_qwen.py +++ b/tests/models/decoder_only/vision_language/test_qwen.py @@ -6,7 +6,7 @@ from PIL.Image import Image from vllm.config import ModelConfig -from vllm.inputs import InputContext, LLMInputs +from vllm.inputs import InputContext, TokenInputs from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size @@ -98,7 +98,7 @@ def test_input_processor_valid_mm_data(input_processor_for_qwen, """Happy cases for image inputs to Qwen's multimodal input processor.""" prompt = "".join( [f"Picture {num}: \n" for num in range(1, num_images + 1)]) - inputs = LLMInputs( + inputs = TokenInputs( prompt=prompt, # When processing multimodal data for a multimodal model, the qwen # input processor will overwrite the provided prompt_token_ids with @@ -161,9 +161,9 @@ def test_input_processor_invalid_mm_data(input_processor_for_qwen, trust_remote_code=True) prompt = "Picture 1: \n" prompt_token_ids = tokenizer.encode(prompt) - inputs = LLMInputs(prompt=prompt, - prompt_token_ids=prompt_token_ids, - multi_modal_data=mm_data) + inputs = TokenInputs(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=mm_data) # Should fail since we have too many or too few dimensions for embeddings with pytest.raises(ValueError): input_processor_for_qwen(qwen_vl_context, inputs) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index fe82485b597f4..e8f5e83e01aac 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -58,7 +58,7 @@ def _do_sample( SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceTokenData.from_seq([1, 2, 3])}, + seq_data={0: SequenceTokenData.from_seqs([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -202,7 +202,7 @@ def create_sampling_params(min_tokens, return sampling_params def create_sequence_data(num_input=3, num_generated=0): - seq_data = SequenceTokenData.from_seq( + seq_data = SequenceTokenData.from_seqs( random.choices(range(0, VOCAB_SIZE), k=num_input)) if num_generated > 0: seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE), @@ -507,7 +507,7 @@ def test_sampler_mixed(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceTokenData.from_seq([1, 2, 3])}, + seq_data={0: SequenceTokenData.from_seqs([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -607,7 +607,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceTokenData.from_seq([1, 2, 3])}, + seq_data={0: SequenceTokenData.from_seqs([1, 2, 3])}, sampling_params=SamplingParams( temperature=1, top_k=top_k, @@ -691,7 +691,7 @@ def test_sampling_params(sampling_params: List[SamplingParams]): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceTokenData.from_seq([1, 2, 3])}, + seq_data={0: SequenceTokenData.from_seqs([1, 2, 3])}, sampling_params=sampling_params[i], block_tables={0: [1]}, )) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 03be62ffa89d4..6183df42961ee 100644 --- a/tests/spec_decode/utils.py +++ 
b/tests/spec_decode/utils.py @@ -138,8 +138,8 @@ def create_seq_group_metadata_from_prompts( is_prompt=len(cont_token_ids) == 0, seq_data={ i: - SequenceTokenData.from_seq(prompt_token_ids[:], - cont_token_ids[:]), + SequenceTokenData.from_seqs(prompt_token_ids[:], + cont_token_ids[:]), }, sampling_params=SamplingParams(temperature=0.0, ), block_tables={i: block_allocations[i][:]}, diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 50139c971a6fb..148e663b2b799 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -70,7 +70,7 @@ def pick_ith(token_ids, logits): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceTokenData.from_seq([1, 2, 3])}, + seq_data={0: SequenceTokenData.from_seqs([1, 2, 3])}, sampling_params=SamplingParams(temperature=0, logits_processors=[pick_ith]), block_tables={0: [1]}, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 9894b730fd091..3afdc1bf64068 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -55,7 +55,7 @@ def test_sampler_output_eq(sample_outputs): def test_sequence_data_prefill(): - seq_data = SequenceTokenData.from_seq([1, 2, 3, 4]) + seq_data = SequenceTokenData.from_seqs([1, 2, 3, 4]) assert seq_data.get_num_uncomputed_tokens() == 4 assert seq_data.get_num_computed_tokens() == 0 # advance by 2 diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 36ca90ad6400b..5f5df143e738c 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -118,10 +118,10 @@ def test_prepare_prompt(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceTokenData.from_seq(range(seq_len)) + seq_data = SequenceTokenData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_lens.append(encoder_seq_len) - encoder_seq_data = SequenceTokenData.from_seq(range(encoder_seq_len)) + encoder_seq_data = SequenceTokenData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -314,9 +314,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceTokenData.from_seq(range(seq_len)) + seq_data = SequenceTokenData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceTokenData.from_seq(range(encoder_seq_len)) + encoder_seq_data = SequenceTokenData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -518,9 +518,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceTokenData.from_seq(range(seq_len)) + seq_data = SequenceTokenData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceTokenData.from_seq(range(encoder_seq_len)) + encoder_seq_data = SequenceTokenData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, diff --git a/tests/worker/test_model_runner.py 
b/tests/worker/test_model_runner.py index 76e8422889510..5d177eb326a47 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -51,13 +51,13 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio): seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) if random.random() < prompt_embeds_ratio: - seq_data = SequenceTokenData.from_seq( + seq_data = SequenceTokenData.from_seqs( range(seq_len), torch.rand(seq_len, 10).tolist(), ) input_embeds_len += seq_len else: - seq_data = SequenceTokenData.from_seq(range(seq_len)) + seq_data = SequenceTokenData.from_seqs(range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -184,13 +184,13 @@ def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio): context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) if random.random() < prompt_embeds_ratio: - seq_data = SequenceTokenData.from_seq( + seq_data = SequenceTokenData.from_seqs( [], torch.rand(context_len, 10).tolist(), ) input_embeds_len += context_len else: - seq_data = SequenceTokenData.from_seq(range(context_len)) + seq_data = SequenceTokenData.from_seqs(range(context_len)) seq_data.update_num_computed_tokens(context_len) # Append one token ID since prefill is finished. seq_data.append_token_id(1, 0) @@ -364,13 +364,13 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) if random.random() < prompt_embeds_ratio: - seq_data = SequenceTokenData.from_seq( + seq_data = SequenceTokenData.from_seqs( [], torch.rand(seq_len, 10).tolist(), ) input_embeds_len += seq_len else: - seq_data = SequenceTokenData.from_seq(range(seq_len)) + seq_data = SequenceTokenData.from_seqs(range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -387,12 +387,12 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 if random.random() < prompt_embeds_ratio: - seq_data = SequenceTokenData.from_seq( + seq_data = SequenceTokenData.from_seqs( [], torch.rand(context_len, 10).tolist(), ), else: - seq_data = SequenceTokenData.from_seq(range(context_len)) + seq_data = SequenceTokenData.from_seqs(range(context_len)) seq_data.append_token_id(1, 0) seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 876b7aa85781a..2fd2d2b012b1b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -28,8 +28,8 @@ from vllm.executor.executor_base import ExecutorBase from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderInputs, InputRegistry, - LLMInputs, PromptType) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, + EncoderDecoderInputs, InputRegistry, PromptType) from vllm.inputs.parse import is_valid_encoder_decoder_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger @@ -618,7 +618,7 @@ def _verify_args(self) -> None: def _add_processed_request( self, request_id: str, - processed_inputs: Union[LLMInputs, EncoderDecoderInputs], + processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: 
Optional[LoRARequest], @@ -1705,7 +1705,7 @@ def _support_prompt_embeds(self) -> Tuple[bool, str]: return False, (f"Model {self.model_config.model} does not support " "input embeddings, but prompt_embeds was provided.") - def _validate_model_inputs(self, inputs: Union[LLMInputs, + def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]): if self.is_encoder_decoder_model(): prompt_ids = inputs.get("encoder_prompt_token_ids") diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 4315d5c2b81e7..cf589fb11d484 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,8 @@ -from .data import (EmbedsPrompt, EncoderDecoderInputs, - ExplicitEncoderDecoderPrompt, LLMInputs, PromptType, - SingletonPrompt, TextPrompt, TokensPrompt, - build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, +from .data import (DecoderOnlyInputs, EmbedInputs, EmbedsPrompt, EmptyInputs, + EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, + PromptType, SingletonPrompt, TextPrompt, TokenInputs, + TokensPrompt, build_explicit_enc_dec_prompt, embed_inputs, + empty_inputs, to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry @@ -21,7 +22,13 @@ "PromptType", "SingletonPrompt", "ExplicitEncoderDecoderPrompt", - "LLMInputs", + "TokenInputs", + "token_inputs", + "EmbedInputs", + "embed_inputs", + "DecoderOnlyInputs", + "EmptyInputs", + "empty_inputs", "EncoderDecoderInputs", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 589c9a83e9500..56cfef9a1d3c8 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -137,6 +137,22 @@ class TokenInputs(TypedDict): """ +def token_inputs( + prompt_token_ids: List[int], + prompt: Optional[str] = None, + multi_modal_data: Optional["MultiModalDataDict"] = None, +) -> TokenInputs: + """Construct :class:`TokenInputs` from optional values.""" + inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) + + if prompt is not None: + inputs["prompt"] = prompt + if multi_modal_data is not None: + inputs["multi_modal_data"] = multi_modal_data + + return inputs + + class EmbedInputs(TypedDict): """Represents embedding-based inputs.""" @@ -153,7 +169,20 @@ class EmbedInputs(TypedDict): """ -LLMInputs = Union[TokenInputs, EmbedInputs] +def embed_inputs( + prompt_embeds: torch.Tensor, + multi_modal_data: Optional["MultiModalDataDict"] = None, +) -> EmbedInputs: + """Construct :class:`EmbedInputs` from optional values.""" + inputs = EmbedInputs(type="embed", prompt_embeds=prompt_embeds) + + if multi_modal_data is not None: + inputs["multi_modal_data"] = multi_modal_data + + return inputs + + +DecoderOnlyInputs = Union[TokenInputs, EmbedInputs] """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. 
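
To make the intended usage of the new helpers concrete, here is a minimal sketch (not part of the diff) that builds both members of the `DecoderOnlyInputs` union via the factory functions added above and branches on the `type` tag; it assumes only the definitions from this file:

    import torch

    from vllm.inputs import DecoderOnlyInputs, embed_inputs, token_inputs

    def describe(inputs: DecoderOnlyInputs) -> str:
        # The "type" tag added in this patch lets downstream code branch on
        # the union without isinstance() checks on TypedDicts.
        if inputs["type"] == "token":
            return f"{len(inputs['prompt_token_ids'])} prompt token ids"
        assert inputs["type"] == "embed"
        return f"prompt embeds of shape {tuple(inputs['prompt_embeds'].shape)}"

    print(describe(token_inputs(prompt_token_ids=[1, 2, 3], prompt="Hello")))
    print(describe(embed_inputs(prompt_embeds=torch.rand(3, 4096))))
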
@@ -169,6 +198,11 @@ class EmptyInputs(TypedDict): """The type of inputs.""" +def empty_inputs() -> EmptyInputs: + """Construct :class:`EmptyInputs` from optional values.""" + return EmptyInputs(type="empty") + + class EncoderDecoderInputs(TypedDict): """ The inputs in :class:`~vllm.LLMEngine` before they are diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index a5211bf04cae4..6c5be9007bced 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -4,9 +4,9 @@ from vllm.utils import is_list_of -from .data import (EmbedsPrompt, EncoderDecoderInputs, - ExplicitEncoderDecoderPrompt, LLMInputs, PromptType, - SingletonPrompt, TextPrompt, TokensPrompt) +from .data import (DecoderOnlyInputs, EmbedsPrompt, EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt, + TextPrompt, TokensPrompt) class ParsedText(TypedDict): @@ -110,6 +110,6 @@ def is_explicit_encoder_decoder_prompt( def is_valid_encoder_decoder_inputs( - inputs: Union[LLMInputs, EncoderDecoderInputs], + inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], ) -> TypeIs[EncoderDecoderInputs]: return "encoder_prompt_token_ids" in inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index f7d94523a4f44..2c80c5b13002a 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,7 +1,6 @@ import asyncio -from typing import TYPE_CHECKING, List, Optional, Union +from typing import List, Optional, Union -import torch from typing_extensions import assert_never from vllm.config import ModelConfig @@ -10,13 +9,11 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from .data import (EmbedInputs, EmptyInputs, EncoderDecoderInputs, LLMInputs, - PromptType, SingletonPrompt, TokenInputs) +from .data import (DecoderOnlyInputs, EmptyInputs, EncoderDecoderInputs, + PromptType, SingletonPrompt, embed_inputs, empty_inputs, + token_inputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt -if TYPE_CHECKING: - from vllm.multimodal import MultiModalDataDict - logger = init_logger(__name__) @@ -200,42 +197,12 @@ async def _tokenize_prompt_async( prompt=prompt, lora_request=lora_request) - def _token_inputs( - self, - prompt_token_ids: List[int], - prompt: Optional[str] = None, - multi_modal_data: Optional["MultiModalDataDict"] = None, - ) -> TokenInputs: - inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) - - if prompt is not None: - inputs["prompt"] = prompt - if multi_modal_data is not None: - inputs["multi_modal_data"] = multi_modal_data - - return inputs - - def _embed_inputs( - self, - prompt_embeds: torch.Tensor, - multi_modal_data: Optional["MultiModalDataDict"] = None, - ) -> EmbedInputs: - inputs = EmbedInputs(type="embed", prompt_embeds=prompt_embeds) - - if multi_modal_data is not None: - inputs["multi_modal_data"] = multi_modal_data - - return inputs - - def _empty_inputs(self) -> EmptyInputs: - return EmptyInputs(type="empty") - def _prompt_to_llm_inputs( self, inputs: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> LLMInputs: + ) -> DecoderOnlyInputs: ''' Extract the components of any single encoder or decoder input prompt. 
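
Since the branches below normalise three different singleton prompt forms, a hedged end-to-end sketch may help reviewers; the model name, token ids, and embedding shape are illustrative assumptions rather than values taken from the patch:

    import torch

    from vllm import LLM, SamplingParams
    from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt

    # Hypothetical model choice; any decoder-only model patched in this
    # series to accept input embeddings would do.
    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(temperature=0.0, max_tokens=8)

    prompts = [
        TextPrompt(prompt="Hello, world"),                  # tokenized here
        TokensPrompt(prompt_token_ids=[2, 31414, 6, 232]),  # pre-tokenized
        EmbedsPrompt(prompt_embeds=torch.rand(4, 768)),     # pre-embedded
    ]
    outputs = llm.generate(prompts, params)
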
@@ -262,7 +229,7 @@ def _prompt_to_llm_inputs( lora_request=lora_request, ) - return self._token_inputs( + return token_inputs( prompt=prompt, prompt_token_ids=prompt_token_ids, ) @@ -271,7 +238,7 @@ def _prompt_to_llm_inputs( prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") - return self._token_inputs( + return token_inputs( prompt_token_ids=prompt_token_ids, multi_modal_data=multi_modal_data, ) @@ -285,7 +252,7 @@ def _prompt_to_llm_inputs( ) multi_modal_data = parsed["content"].get("multi_modal_data") - return self._token_inputs( + return token_inputs( prompt=prompt, prompt_token_ids=prompt_token_ids, multi_modal_data=multi_modal_data, @@ -295,7 +262,7 @@ def _prompt_to_llm_inputs( prompt_embeds = parsed["content"]["prompt_embeds"] multi_modal_data = parsed["content"].get("multi_modal_data") - return self._embed_inputs( + return embed_inputs( prompt_embeds=prompt_embeds, multi_modal_data=multi_modal_data, ) @@ -307,7 +274,7 @@ async def _prompt_to_llm_inputs_async( inputs: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> LLMInputs: + ) -> DecoderOnlyInputs: """Async version of :meth:`_extract_prompt_components`.""" parsed = parse_singleton_prompt(inputs) @@ -319,7 +286,7 @@ async def _prompt_to_llm_inputs_async( lora_request=lora_request, ) - return self._token_inputs( + return token_inputs( prompt=prompt, prompt_token_ids=prompt_token_ids, ) @@ -328,7 +295,7 @@ async def _prompt_to_llm_inputs_async( prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") - return self._token_inputs( + return token_inputs( prompt_token_ids=prompt_token_ids, multi_modal_data=multi_modal_data, ) @@ -342,7 +309,7 @@ async def _prompt_to_llm_inputs_async( ) multi_modal_data = parsed["content"].get("multi_modal_data") - return self._token_inputs( + return token_inputs( prompt=prompt, prompt_token_ids=prompt_token_ids, multi_modal_data=multi_modal_data, @@ -352,7 +319,7 @@ async def _prompt_to_llm_inputs_async( prompt_embeds = parsed["content"]["prompt_embeds"] multi_modal_data = parsed["content"].get("multi_modal_data") - return self._embed_inputs( + return embed_inputs( prompt_embeds=prompt_embeds, multi_modal_data=multi_modal_data, ) @@ -361,8 +328,8 @@ async def _prompt_to_llm_inputs_async( def _build_enc_dec_llm_inputs( self, - encoder_inputs: LLMInputs, - decoder_inputs: Union[EmptyInputs, LLMInputs], + encoder_inputs: DecoderOnlyInputs, + decoder_inputs: Union[EmptyInputs, DecoderOnlyInputs], ) -> EncoderDecoderInputs: if encoder_inputs["type"] == "token": if "multi_modal_data" in encoder_inputs: @@ -402,7 +369,7 @@ def _process_encoder_decoder_prompt( ''' For encoder/decoder models only: Process an input prompt into an - :class:`EncoderDecoderLLMInputs` instance. + :class:`EncoderDecoderDecoderOnlyInputs` instance. 
There are two types of input prompts: singleton prompts which carry only the @@ -429,11 +396,11 @@ def _process_encoder_decoder_prompt( Returns: - * :class:`EncoderDecoderLLMInputs` instance + * :class:`EncoderDecoderDecoderOnlyInputs` instance ''' - encoder_inputs: LLMInputs - decoder_inputs: Union[EmptyInputs, LLMInputs] + encoder_inputs: DecoderOnlyInputs + decoder_inputs: Union[EmptyInputs, DecoderOnlyInputs] if is_explicit_encoder_decoder_prompt(prompt): encoder_inputs = self._prompt_to_llm_inputs( @@ -442,7 +409,7 @@ def _process_encoder_decoder_prompt( ) if (decoder_input := prompt["decoder_prompt"]) is None: - decoder_inputs = self._empty_inputs() + decoder_inputs = empty_inputs() else: decoder_inputs = self._prompt_to_llm_inputs( decoder_input, @@ -454,7 +421,7 @@ def _process_encoder_decoder_prompt( request_id=request_id, ) - decoder_inputs = self._empty_inputs() + decoder_inputs = empty_inputs() return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) @@ -464,8 +431,8 @@ async def _process_encoder_decoder_prompt_async( request_id: str, ) -> EncoderDecoderInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" - encoder_inputs: LLMInputs - decoder_inputs: Union[EmptyInputs, LLMInputs] + encoder_inputs: DecoderOnlyInputs + decoder_inputs: Union[EmptyInputs, DecoderOnlyInputs] if is_explicit_encoder_decoder_prompt(inputs): encoder_task = self._prompt_to_llm_inputs_async( @@ -475,7 +442,7 @@ async def _process_encoder_decoder_prompt_async( if (decoder_input := inputs["decoder_prompt"]) is None: encoder_inputs = await encoder_task - decoder_inputs = self._empty_inputs() + decoder_inputs = empty_inputs() else: decoder_task = self._prompt_to_llm_inputs_async( decoder_input, @@ -490,15 +457,15 @@ async def _process_encoder_decoder_prompt_async( request_id=request_id, ) - decoder_inputs = self._empty_inputs() + decoder_inputs = empty_inputs() return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) def _build_decoder_only_llm_inputs( self, - prompt_inputs: LLMInputs, + prompt_inputs: DecoderOnlyInputs, prompt_adapter_request: Optional[PromptAdapterRequest], - ) -> LLMInputs: + ) -> DecoderOnlyInputs: if prompt_inputs["type"] == "token": if "prompt_token_ids" in prompt_inputs: prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( @@ -518,10 +485,10 @@ def _process_decoder_only_prompt( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: + ) -> DecoderOnlyInputs: ''' For decoder-only models: - Process an input prompt into an :class:`LLMInputs` instance. + Process an input prompt into an :class:`DecoderOnlyInputs` instance. 
Arguments: @@ -532,7 +499,7 @@ def _process_decoder_only_prompt( Returns: - * :class:`LLMInputs` instance + * :class:`DecoderOnlyInputs` instance ''' prompt_comps = self._prompt_to_llm_inputs( @@ -552,7 +519,7 @@ async def _process_decoder_only_prompt_async( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: + ) -> DecoderOnlyInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" prompt_comps = await self._prompt_to_llm_inputs_async( inputs, @@ -571,7 +538,7 @@ def preprocess( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[LLMInputs, EncoderDecoderInputs]: + ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: """Preprocess the input prompt.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of @@ -599,7 +566,7 @@ async def preprocess_async( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[LLMInputs, EncoderDecoderInputs]: + ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: """Async version of :meth:`preprocess`.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index f74b1ea4e5825..57265b151025c 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger -from .data import LLMInputs +from .data import DecoderOnlyInputs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -97,7 +97,7 @@ def __getitem__(self, key: str) -> int: raise KeyError(msg) from exc -InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs] +InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs] """Preprocess the inputs to the model.""" @@ -129,7 +129,7 @@ def _default_dummy_data_factory( # Avoid circular import from vllm.sequence import SequenceTokenData - dummy_seq_data = SequenceTokenData.from_seq([0] * seq_len) + dummy_seq_data = SequenceTokenData.from_seqs([0] * seq_len) dummy_multi_modal_data = None return dummy_seq_data, dummy_multi_modal_data @@ -204,8 +204,9 @@ def dummy_data_for_profiling( return seq_data, mm_data - def _default_input_processor(self, ctx: InputContext, - inputs: LLMInputs) -> LLMInputs: + def _default_input_processor( + self, ctx: InputContext, + inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: """The default input processor is a no-op.""" return inputs @@ -234,7 +235,7 @@ def wrapper(model_cls: N) -> N: return wrapper def process_input(self, model_config: "ModelConfig", - inputs: LLMInputs) -> LLMInputs: + inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: """ Apply an input processor to an instance of model inputs. 
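
Before the per-model changes that follow, a short sketch of what an `InputProcessor` looks like under the new `DecoderOnlyInputs` signature; the processor and model class here are hypothetical and only illustrate the shape of the API:

    from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                             token_inputs)

    def my_input_processor(ctx: InputContext,
                           llm_inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
        # Embedding inputs carry no token ids to rewrite, so pass them through.
        if llm_inputs["type"] != "token":
            return llm_inputs
        # For token inputs, prepend a hypothetical BOS id of 1.
        return token_inputs(
            prompt_token_ids=[1] + list(llm_inputs["prompt_token_ids"]),
            prompt=llm_inputs.get("prompt"),
            multi_modal_data=llm_inputs.get("multi_modal_data"),
        )

    # Registration follows the same decorator pattern as the model files below:
    #
    #     @INPUT_REGISTRY.register_input_processor(my_input_processor)
    #     class MyModelForCausalLM(nn.Module):
    #         ...
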
diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 583d5d217903b..168ec822858cd 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -1,6 +1,5 @@ """Minimal implementation of BlipVisionModel intended to be only used within a vision language model.""" -from array import array from typing import Optional, Union import torch @@ -11,7 +10,7 @@ from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import LLMInputs +from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -19,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.sequence import SequenceTokenData try: from xformers import ops as xops @@ -62,11 +61,10 @@ def dummy_seq_data_for_blip( else: image_feature_size = image_feature_size_override - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [image_token_id]) * image_feature_size - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - image_feature_size) - return SequenceData(token_ids) + return SequenceTokenData.from_counts({ + image_token_id: image_feature_size, + 0: (seq_len - image_feature_size), + }) def dummy_image_for_blip( @@ -89,7 +87,7 @@ def dummy_image_for_blip( def input_processor_for_blip( model_config: ModelConfig, hf_config: Union[BlipVisionConfig, Blip2VisionConfig], - llm_inputs: LLMInputs, + llm_inputs: DecoderOnlyInputs, *, image_token_id: int, image_feature_size_override: Optional[int] = None, @@ -114,9 +112,9 @@ def input_processor_for_blip( ) # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 39f2b2d853a6b..7ebf2c529c434 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,4 +1,3 @@ -from array import array from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -9,7 +8,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -18,8 +18,7 @@ from vllm.model_executor.models.opt import OPTModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceTokenData from .blip import (BlipVisionModel, dummy_image_for_blip, get_max_blip_image_tokens) @@ 
-429,11 +428,12 @@ def dummy_seq_data_for_blip2( else: image_feature_size = image_feature_size_override - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [image_token_id]) * image_feature_size * num_images - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids) + return SequenceTokenData.from_counts({ + image_token_id: + image_feature_size * num_images, + 0: + seq_len - image_feature_size * num_images, + }) def dummy_data_for_blip2(ctx: InputContext, seq_len: int, @@ -458,7 +458,8 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, raise NotImplementedError(msg) -def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_blip2(ctx: InputContext, + llm_inputs: DecoderOnlyInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs @@ -475,9 +476,9 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): if new_prompt is not None: new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) @MULTIMODAL_REGISTRY.register_image_input_mapper() diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 47e020e8ecb73..ef7a326fcac91 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -1,4 +1,3 @@ -from array import array from functools import cached_property from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict) @@ -12,7 +11,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -32,8 +32,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceTokenData from vllm.utils import print_warning_once from .interfaces import SupportsMultiModal @@ -72,11 +71,12 @@ def dummy_seq_data_for_chameleon( else: image_feature_size = image_feature_size_override - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [image_token_id]) * image_feature_size * num_images - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids) + return SequenceTokenData.from_counts({ + image_token_id: + image_feature_size * num_images, + 0: + seq_len - image_feature_size * num_images, + }) def dummy_image_for_chameleon( @@ -110,7 +110,8 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, return seq_data, mm_data -def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_chameleon(ctx: InputContext, + llm_inputs: DecoderOnlyInputs): """ Processing input prompt to insert required 
tokens for image placeholder. @@ -141,9 +142,9 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): new_token_ids += [CHAMELEON_SEP_TOKEN_ID] # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class ChameleonLayerNorm(nn.LayerNorm): diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 078928f281c26..8ea02ea6a84ac 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,6 +1,5 @@ """Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" -from array import array from typing import Iterable, List, Optional, Tuple, Union import torch @@ -11,7 +10,7 @@ from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import LLMInputs +from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -20,7 +19,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.sequence import SequenceTokenData try: from xformers import ops as xops @@ -62,11 +61,12 @@ def dummy_seq_data_for_clip( else: image_feature_size = image_feature_size_override - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [image_token_id]) * image_feature_size * num_images - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids) + return SequenceTokenData.from_counts({ + image_token_id: + image_feature_size * num_images, + 0: + seq_len - image_feature_size * num_images, + }) def dummy_image_for_clip( @@ -89,7 +89,7 @@ def dummy_image_for_clip( def input_processor_for_clip( model_config: ModelConfig, hf_config: CLIPVisionConfig, - llm_inputs: LLMInputs, + llm_inputs: DecoderOnlyInputs, *, image_token_id: int, image_feature_size_override: Optional[Union[int, List[int]]] = None, @@ -120,9 +120,9 @@ def input_processor_for_clip( ) # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 3e3fae62ce92f..50f80a76a8b0f 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -27,7 +27,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig @@ -40,7 +41,7 @@ from 
vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) + SequenceTokenData) from .interfaces import SupportsMultiModal from .utils import merge_multimodal_embeddings @@ -106,7 +107,7 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int): token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids) + return SequenceTokenData.from_seqs(token_ids) def dummy_image_for_fuyu( @@ -153,7 +154,7 @@ def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, return model_image_input -def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_fuyu(ctx: InputContext, llm_inputs: DecoderOnlyInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs @@ -194,9 +195,9 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[ 1:] + boa_token - return LLMInputs(prompt=new_prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=new_multi_modal_data) + return token_inputs(prompt=new_prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=new_multi_modal_data) def input_mapper_for_fuyu(ctx: InputContext, data: object): diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 507d7014714a2..0613ac7507b4f 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -18,7 +18,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -193,7 +194,8 @@ def get_max_internvl_image_tokens(ctx: InputContext): return num_patches * max_dynamic_patch -def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_internvl(ctx: InputContext, + llm_inputs: DecoderOnlyInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs @@ -248,9 +250,9 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): new_prompt = new_prompt.replace('', image_prompt, 1) new_prompt_token_ids = tokenizer.encode(new_prompt) - return LLMInputs(prompt=prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=multi_modal_data) + return token_inputs(prompt=prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data) def input_mapper_for_internvl(ctx: InputContext, data: object): diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 7a6c991fb133a..8992830acf431 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -9,7 +9,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig 
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -126,7 +126,8 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, raise NotImplementedError(msg) -def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_llava(ctx: InputContext, + llm_inputs: DecoderOnlyInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c6bd46dd7eda9..9862f44217056 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -12,7 +12,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -208,7 +208,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, raise NotImplementedError(msg) -def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_llava_next(ctx: InputContext, + llm_inputs: DecoderOnlyInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 7fe85e5e4ab3d..4e02409f29731 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -11,7 +11,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization.base_config import ( @@ -144,7 +145,7 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int, def input_processor_for_llava_next_video(ctx: InputContext, - llm_inputs: LLMInputs): + llm_inputs: DecoderOnlyInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "video" not in multi_modal_data: return llm_inputs @@ -171,9 +172,9 @@ def input_processor_for_llava_next_video(ctx: InputContext, repeat_count=video_feature_size, ) - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) elif is_list_of(video_data, np.ndarray): raise NotImplementedError( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index f0fc950defed7..8b79636101a88 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -23,7 +23,6 @@ """Inference-only MiniCPM-V model compatible with HuggingFace 
weights.""" import math import re -from array import array from functools import partial from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple, TypedDict) @@ -37,7 +36,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -56,8 +56,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceTokenData from .idefics2_vision_model import Idefics2VisionTransformer @@ -259,8 +258,9 @@ def get_max_minicpmv_image_tokens(ctx: InputContext): def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len - return SequenceData(token_ids) + return SequenceTokenData.from_counts({ + 0: seq_len, + }) def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int): @@ -280,7 +280,8 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, return seq_data, mm_data -def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_minicpmv(ctx: InputContext, + llm_inputs: DecoderOnlyInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs @@ -323,7 +324,7 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int): new_prompt = "".join(new_prompt_chunks) new_token_ids = tokenizer.encode(new_prompt) - llm_inputs = LLMInputs( + llm_inputs = token_inputs( prompt_token_ids=new_token_ids, prompt=new_prompt, multi_modal_data=multi_modal_data, diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 5fd39b5e35be6..8743ffe90f7f6 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -8,7 +8,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -71,7 +72,8 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, return seq_data, mm_data -def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_paligemma(ctx: InputContext, + llm_inputs: DecoderOnlyInputs): """ The correct prompt format needs to be: @@ -109,9 +111,9 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class 
PaliGemmaMultiModalProjector(nn.Module): diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 94de0c5ae9edf..ae467232dcc6c 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -27,7 +27,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -398,7 +398,8 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig, return image_placeholder_token_ids -def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_phi3v(ctx: InputContext, + llm_inputs: DecoderOnlyInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs @@ -485,9 +486,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): new_token_ids.append(token_id) # NOTE: Create a defensive copy of the original inputs - llm_inputs = LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + llm_inputs = DecoderOnlyInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) return llm_inputs diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 682b78bbed093..b236e36f62c09 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1,4 +1,3 @@ -from array import array from dataclasses import dataclass, fields from itertools import tee from typing import Iterable, List, Mapping, Optional, Tuple, Union @@ -14,7 +13,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -24,8 +23,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceTokenData from .interfaces import SupportsMultiModal from .utils import init_vllm_registered_model @@ -63,13 +61,11 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, image_feature_size = (size**2) // (patch_size**2) num_image_tokens = image_feature_size * num_images + seq_data = SequenceTokenData.from_counts({ + image_token_id: num_image_tokens, + 0: seq_len - num_image_tokens, + }) - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [image_token_id]) * num_image_tokens - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - num_image_tokens) - - seq_data = SequenceData(token_ids) mm_data = {"image": num_images * [image]} return seq_data, mm_data @@ -105,7 +101,8 @@ def input_mapper_for_pixtral(ctx: InputContext, return MultiModalInputs({"images": images}) -def input_processor_for_pixtral(ctx: InputContext, 
llm_inputs: LLMInputs): +def input_processor_for_pixtral(ctx: InputContext, + llm_inputs: DecoderOnlyInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is not None and "image" in multi_modal_data: tokenizer = cached_get_tokenizer( diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 18bc6b303f485..d284ee2862547 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -7,7 +7,6 @@ import math import re -from array import array from functools import partial from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -23,7 +22,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm @@ -45,8 +45,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceTokenData from vllm.utils import is_list_of from .utils import flatten_bn, is_pp_missing_parameter, make_layers @@ -650,8 +649,8 @@ def get_image_text(image_num: int, padding: bool) -> str: return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}" -def input_processor_for_qwen(ctx: InputContext, - llm_inputs: LLMInputs) -> LLMInputs: +def input_processor_for_qwen( + ctx: InputContext, llm_inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: """Processes the inputs, which may or may not be multimodal. Multimodal inputs will only be processed if the model has a "visual" component in its model config, otherwise they'll be ignored. @@ -711,9 +710,9 @@ def input_processor_for_qwen(ctx: InputContext, new_prompt_token_ids = tokenizer.encode(new_prompt) - return LLMInputs(prompt=new_prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=multi_modal_data) + return token_inputs(prompt=new_prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data) def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: @@ -802,7 +801,7 @@ def dummy_data_for_qwen( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], -) -> Tuple[SequenceData, Optional[Dict]]: +) -> Tuple[SequenceTokenData, Optional[Dict]]: """Build dummy data for warming up Qwen models; this will only contain text matching the defaults for VLLM unless the model has a visual config. @@ -819,7 +818,7 @@ def dummy_data_for_qwen( # The presence of a visual config indicates this is a multimodal model. # If we don't have it, the model is considered an LLM for warmup purposes. 
if not hasattr(hf_config, "visual"): - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)) + seq_data = SequenceTokenData.from_counts({0: seq_len}) mm_data = None return seq_data, mm_data @@ -846,11 +845,13 @@ def dummy_data_for_qwen( if len(toks) < seq_len: toks += [0] * (seq_len - len(toks)) + seq_data = SequenceTokenData.from_seqs(toks) + # Build the input images; width/height doesn't actually matter here since # the data will get resized and the # of tokens per image is constant image = Image.new("RGB", (224, 224), color=0) mm_data = {"image": image if num_images == 1 else [image] * num_images} - return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data + return seq_data, mm_data @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index a9a0329e99f08..0d97a5958164c 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,7 +22,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" -from array import array from functools import lru_cache, partial from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union) @@ -48,7 +47,8 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU @@ -66,8 +66,7 @@ from vllm.multimodal.base import MultiModalData from vllm.multimodal.image import cached_get_image_processor from vllm.platforms import current_platform -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceTokenData from vllm.transformers_utils.processor import get_processor logger = init_logger(__name__) @@ -656,7 +655,7 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int: def dummy_data_for_qwen2_vl( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] -) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: +) -> Tuple[SequenceTokenData, Optional[MultiModalDataDict]]: image_processor = cached_get_image_processor(ctx.model_config.model) num_images = mm_counts["image"] @@ -681,15 +680,17 @@ def dummy_data_for_qwen2_vl( "--limit-mm-per-prompt.") hf_config = ctx.get_hf_config(Qwen2VLConfig) - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [hf_config.vision_start_token_id]) - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [hf_config.image_token_id]) * max_llm_image_tokens - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [hf_config.vision_end_token_id]) - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - max_llm_image_tokens - 2) - dummy_seqdata = SequenceData(token_ids) + + dummy_seqdata = SequenceTokenData.from_counts({ + hf_config.vision_start_token_id: + 1, + hf_config.image_token_id: + max_llm_image_tokens, + hf_config.vision_end_token_id: + 1, + 0: + seq_len - max_llm_image_tokens - 2, + }) dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), color=0) @@ -724,8 +725,8 @@ def _get_llm_num_vision_tokens( return 
llm_num_vision_tokens -def input_processor_for_qwen2_vl(ctx: InputContext, - llm_inputs: LLMInputs) -> LLMInputs: +def input_processor_for_qwen2_vl( + ctx: InputContext, llm_inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: multi_modal_data = llm_inputs.get("multi_modal_data", None) if multi_modal_data is None: return llm_inputs @@ -817,7 +818,7 @@ def input_processor_for_qwen2_vl(ctx: InputContext, 1:]) prompt_token_ids = prompt_token_ids_with_video - return LLMInputs( + return token_inputs( prompt_token_ids=prompt_token_ids, prompt=llm_inputs["prompt"], multi_modal_data=multi_modal_data, diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index f7976eba7420b..1c89682cd76b5 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -2,7 +2,6 @@ within a vision language model.""" import math -from array import array from typing import Iterable, List, Optional, Tuple, Union import torch @@ -13,7 +12,7 @@ from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import LLMInputs +from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -24,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.sequence import SequenceTokenData try: from xformers import ops as xops @@ -67,11 +66,10 @@ def dummy_seq_data_for_siglip( else: image_feature_size = image_feature_size_override - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [image_token_id]) * image_feature_size - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - image_feature_size) - return SequenceData(token_ids) + return SequenceTokenData.from_counts({ + image_token_id: image_feature_size, + 0: (seq_len - image_feature_size), + }) def dummy_image_for_siglip( @@ -94,7 +92,7 @@ def dummy_image_for_siglip( def input_processor_for_siglip( model_config: ModelConfig, hf_config: SiglipVisionConfig, - llm_inputs: LLMInputs, + llm_inputs: DecoderOnlyInputs, *, image_token_id: int, image_feature_size_override: Optional[Union[int, List[int]]] = None, @@ -125,7 +123,7 @@ def input_processor_for_siglip( ) # NOTE: Create a defensive copy of the original inputs - return LLMInputs( + return token_inputs( prompt_token_ids=new_token_ids, prompt=new_prompt, multi_modal_data=multi_modal_data, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 416fabda831a2..81e047514f1ea 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -19,7 +19,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY -from vllm.inputs.data import LLMInputs +from vllm.inputs.data import DecoderOnlyInputs, token_inputs from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn @@ -37,7 +37,7 @@ from vllm.multimodal.base import MultiModalInputs, NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from 
vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceTokenData from vllm.transformers_utils.configs.ultravox import UltravoxConfig _AUDIO_PLACEHOLDER_TOKEN = 128002 @@ -96,10 +96,12 @@ def dummy_data_for_ultravox( other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - len(audio_token_ids)) + seq_data = SequenceTokenData.from_seqs(audio_token_ids + other_token_ids) + audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1) mm_dict = {"audio": [audio_and_sr] * audio_count} - return (SequenceData(audio_token_ids + other_token_ids), mm_dict) + return (seq_data, mm_dict) def input_mapper_for_ultravox(ctx: InputContext, data: object): @@ -141,7 +143,8 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): return MultiModalInputs({"audio_features": audio_features}) -def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_ultravox(ctx: InputContext, + llm_inputs: DecoderOnlyInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "audio" not in multi_modal_data: return llm_inputs @@ -183,9 +186,9 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): ) # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class StackAudioFrames(nn.Module): diff --git a/vllm/sequence.py b/vllm/sequence.py index 4fca8b7071769..77ec1c762d2d6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,7 +5,7 @@ from array import array from collections import defaultdict from dataclasses import dataclass -from functools import cached_property +from functools import cached_property, reduce from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, Optional) from typing import Sequence as GenericSequence @@ -22,7 +22,7 @@ from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if TYPE_CHECKING: - from vllm.inputs import EncoderDecoderInputs, LLMInputs + from vllm.inputs import DecoderOnlyInputs, EncoderDecoderInputs from vllm.multimodal.base import MultiModalDataDict VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -324,7 +324,19 @@ class SequenceTokenData(SequenceDataMixin, _TokenBase, ...] = msgspec.field(default_factory=tuple) @staticmethod - def from_seq( + def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceTokenData": + if len(counts_by_token) == 0: + return SequenceTokenData.from_seqs([]) + + arrs = [ + array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count + for token_id, count in counts_by_token.items() + ] + + return SequenceTokenData(reduce(lambda a, b: a + b, arrs)) + + @staticmethod + def from_seqs( prompt_token_ids: GenericSequence[int], output_token_ids: Optional[GenericSequence[int]] = None, ) -> "SequenceTokenData": @@ -475,7 +487,7 @@ def __repr__(self) -> str: class Sequence: """Stores the data, status, and block information of a sequence. - The sequence is constructed from the LLMInputs (for decoder-only) or + The sequence is constructed from the DecoderOnlyInputs (for decoder-only) or EncoderDecoderInputs (for encoder-decoder) instance passed in through the `inputs` constructor argument. 
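For reference, a minimal sketch of how the two constructors added in this hunk are meant to be used. `SequenceTokenData` is the class name at this point in the series (a later patch in this set folds it back into a single `SequenceData`), and the token ids below are illustrative values only:

    from vllm.sequence import SequenceTokenData

    # From explicit token ids, e.g. a tokenized prompt; already-generated
    # output ids can be passed as the optional second argument.
    seq_data = SequenceTokenData.from_seqs([1, 2, 3, 4])
    assert seq_data.get_prompt_len() == 4

    # From per-token counts, e.g. dummy/profiling data: 16 copies of a
    # placeholder id followed by zero padding up to the target length.
    dummy = SequenceTokenData.from_counts({1000: 16, 0: 48})
    assert dummy.get_prompt_len() == 64
    assert dummy.get_prompt_token_ids()[:2] == (1000, 1000)
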
@@ -493,7 +505,7 @@ class Sequence: def __init__( self, seq_id: int, - inputs: Union["LLMInputs", "EncoderDecoderInputs"], + inputs: Union["DecoderOnlyInputs", "EncoderDecoderInputs"], block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, @@ -508,7 +520,7 @@ def __init__( data: SequenceData if self.prompt_token_ids: - data = SequenceTokenData.from_seq(self.prompt_token_ids) + data = SequenceTokenData.from_seqs(self.prompt_token_ids) else: assert self.prompt_embeds is not None data = SequenceEmbedData(self.prompt_embeds) From 60bc7b5c9e8cf830d24a4992ecef0fd7874d7e03 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 16:30:51 +0000 Subject: [PATCH 39/88] Fix error on import --- vllm/sequence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 77ec1c762d2d6..c80fb4840cd51 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -294,7 +294,7 @@ def stage(self) -> SequenceStage: return self._stage -class _TokenBase(msgspec.Struct, omit_defaults=True): # type: ignore[call-arg] +class _TokenBase: # NOTE: we cannot use Union[List, array] because msgspec cannot support # union of 2 list types. _prompt_token_ids: array @@ -410,7 +410,7 @@ def __repr__(self) -> str: f"get_num_computed_tokens={self.get_num_computed_tokens()}") -class _EmbedBase(msgspec.Struct, omit_defaults=True): # type: ignore[call-arg] +class _EmbedBase: _prompt_embeds: torch.Tensor From a8483a4acb031770f81c8a3197dcc9f4a51704b7 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 16:39:11 +0000 Subject: [PATCH 40/88] Revert class splitting --- tests/samplers/test_sampler.py | 15 +- tests/spec_decode/utils.py | 8 +- tests/test_logits_processor.py | 5 +- tests/test_sequence.py | 6 +- .../test_encoder_decoder_model_runner.py | 15 +- tests/worker/test_model_runner.py | 19 +- vllm/inputs/registry.py | 12 +- vllm/model_executor/models/blip.py | 4 +- vllm/model_executor/models/blip2.py | 4 +- vllm/model_executor/models/chameleon.py | 4 +- vllm/model_executor/models/clip.py | 4 +- vllm/model_executor/models/fuyu.py | 4 +- vllm/model_executor/models/minicpmv.py | 4 +- vllm/model_executor/models/pixtral.py | 4 +- vllm/model_executor/models/qwen.py | 8 +- vllm/model_executor/models/qwen2_vl.py | 6 +- vllm/model_executor/models/siglip.py | 4 +- vllm/model_executor/models/ultravox.py | 4 +- vllm/sequence.py | 302 ++++++------------ 19 files changed, 154 insertions(+), 278 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index e8f5e83e01aac..308b708feab71 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -11,8 +11,7 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import (SamplingParams, SequenceGroupMetadata, - SequenceTokenData) +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import Counter, is_pin_memory_available @@ -58,7 +57,7 @@ def _do_sample( SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceTokenData.from_seqs([1, 2, 3])}, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -202,7 +201,7 @@ def create_sampling_params(min_tokens, return sampling_params def create_sequence_data(num_input=3, num_generated=0): - seq_data = 
SequenceTokenData.from_seqs( + seq_data = SequenceData.from_seqs( random.choices(range(0, VOCAB_SIZE), k=num_input)) if num_generated > 0: seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE), @@ -234,7 +233,7 @@ def generate_test_case(): eos_token_id=eos_token_id, stop_token_ids=stop_token_ids) - seq_data: Dict[int, SequenceTokenData] = {} + seq_data: Dict[int, SequenceData] = {} seq_group_penalization: List[bool] = [] for _ in range(num_seqs): num_input = random.randint(1, 100) @@ -507,7 +506,7 @@ def test_sampler_mixed(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceTokenData.from_seqs([1, 2, 3])}, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -607,7 +606,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceTokenData.from_seqs([1, 2, 3])}, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=SamplingParams( temperature=1, top_k=top_k, @@ -691,7 +690,7 @@ def test_sampling_params(sampling_params: List[SamplingParams]): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceTokenData.from_seqs([1, 2, 3])}, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=sampling_params[i], block_tables={0: [1]}, )) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 6183df42961ee..f17e872881633 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -11,8 +11,7 @@ from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceGroupMetadata, SequenceOutput, - SequenceTokenData) + SequenceData, SequenceGroupMetadata, SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner @@ -137,9 +136,8 @@ def create_seq_group_metadata_from_prompts( request_id=str(i), is_prompt=len(cont_token_ids) == 0, seq_data={ - i: - SequenceTokenData.from_seqs(prompt_token_ids[:], - cont_token_ids[:]), + i: SequenceData.from_seqs(prompt_token_ids[:], + cont_token_ids[:]), }, sampling_params=SamplingParams(temperature=0.0, ), block_tables={i: block_allocations[i][:]}, diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 148e663b2b799..39c1c38151fd0 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -8,8 +8,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import (SamplingParams, SequenceGroupMetadata, - SequenceTokenData) +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import is_pin_memory_available @@ -70,7 +69,7 @@ def pick_ith(token_ids, logits): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceTokenData.from_seqs([1, 2, 3])}, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=SamplingParams(temperature=0, logits_processors=[pick_ith]), block_tables={0: [1]}, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 3afdc1bf64068..30e53a180ea31 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,8 +1,8 
@@ import pytest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (CompletionSequenceGroupOutput, SequenceOutput, - SequenceTokenData) +from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData, + SequenceOutput) from .core.utils import create_dummy_prompt @@ -55,7 +55,7 @@ def test_sampler_output_eq(sample_outputs): def test_sequence_data_prefill(): - seq_data = SequenceTokenData.from_seqs([1, 2, 3, 4]) + seq_data = SequenceData.from_seqs([1, 2, 3, 4]) assert seq_data.get_num_uncomputed_tokens() == 4 assert seq_data.get_num_computed_tokens() == 0 # advance by 2 diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 5f5df143e738c..3dccc1b325d95 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -5,8 +5,7 @@ import torch from vllm.engine.arg_utils import EngineArgs -from vllm.sequence import (SamplingParams, SequenceGroupMetadata, - SequenceTokenData) +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import is_cpu, make_tensor_with_pad from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import _get_graph_batch_size @@ -118,10 +117,10 @@ def test_prepare_prompt(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceTokenData.from_seqs(range(seq_len)) + seq_data = SequenceData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_lens.append(encoder_seq_len) - encoder_seq_data = SequenceTokenData.from_seqs(range(encoder_seq_len)) + encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -314,9 +313,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceTokenData.from_seqs(range(seq_len)) + seq_data = SequenceData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceTokenData.from_seqs(range(encoder_seq_len)) + encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -518,9 +517,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceTokenData.from_seqs(range(seq_len)) + seq_data = SequenceData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceTokenData.from_seqs(range(encoder_seq_len)) + encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 5d177eb326a47..55cf481e8bb63 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -8,8 +8,7 @@ init_distributed_environment) from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import (SamplingParams, SequenceGroupMetadata, - 
SequenceTokenData) +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @@ -51,13 +50,13 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio): seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) if random.random() < prompt_embeds_ratio: - seq_data = SequenceTokenData.from_seqs( + seq_data = SequenceData.from_seqs( range(seq_len), torch.rand(seq_len, 10).tolist(), ) input_embeds_len += seq_len else: - seq_data = SequenceTokenData.from_seqs(range(seq_len)) + seq_data = SequenceData.from_seqs(range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -184,13 +183,13 @@ def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio): context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) if random.random() < prompt_embeds_ratio: - seq_data = SequenceTokenData.from_seqs( + seq_data = SequenceData.from_seqs( [], torch.rand(context_len, 10).tolist(), ) input_embeds_len += context_len else: - seq_data = SequenceTokenData.from_seqs(range(context_len)) + seq_data = SequenceData.from_seqs(range(context_len)) seq_data.update_num_computed_tokens(context_len) # Append one token ID since prefill is finished. seq_data.append_token_id(1, 0) @@ -364,13 +363,13 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) if random.random() < prompt_embeds_ratio: - seq_data = SequenceTokenData.from_seqs( + seq_data = SequenceData.from_seqs( [], torch.rand(seq_len, 10).tolist(), ) input_embeds_len += seq_len else: - seq_data = SequenceTokenData.from_seqs(range(seq_len)) + seq_data = SequenceData.from_seqs(range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -387,12 +386,12 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 if random.random() < prompt_embeds_ratio: - seq_data = SequenceTokenData.from_seqs( + seq_data = SequenceData.from_seqs( [], torch.rand(context_len, 10).tolist(), ), else: - seq_data = SequenceTokenData.from_seqs(range(context_len)) + seq_data = SequenceData.from_seqs(range(context_len)) seq_data.append_token_id(1, 0) seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 57265b151025c..6331a5e9971ef 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from vllm.config import ModelConfig from vllm.multimodal import MultiModalDataDict, MultiModalRegistry - from vllm.sequence import SequenceTokenData + from vllm.sequence import SequenceData logger = init_logger(__name__) @@ -72,7 +72,7 @@ def __call__( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], - ) -> Tuple["SequenceTokenData", Optional["MultiModalDataDict"]]: + ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ Create dummy data to be inputted into the model. 
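As a rough illustration only (not part of this patch), a dummy-data factory for a hypothetical multimodal model could be written against the unified `SequenceData` as follows; the function name, the placeholder token id, and the per-image feature size are all made up for the sketch, and it assumes `seq_len` is large enough to hold the image tokens:

    from typing import Mapping, Optional, Tuple

    from vllm.inputs import InputContext
    from vllm.multimodal import MultiModalDataDict
    from vllm.sequence import SequenceData

    def dummy_data_for_my_model(
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
        # Hypothetical placeholder id and per-image feature size.
        image_token_id = 99999
        image_feature_size = 256
        num_image_tokens = image_feature_size * mm_counts["image"]

        # Placeholder tokens first, then zero padding to fill the sequence.
        seq_data = SequenceData.from_counts({
            image_token_id: num_image_tokens,
            0: seq_len - num_image_tokens,
        })
        # Real factories also return dummy multi-modal data (as the model
        # files changed below do); this sketch returns None for brevity.
        return seq_data, None
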
@@ -118,7 +118,7 @@ def _default_dummy_data_factory( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], - ) -> Tuple["SequenceTokenData", Optional["MultiModalDataDict"]]: + ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ The default dummy data factory represents the longest possible text that can be inputted to the model. @@ -127,9 +127,9 @@ def _default_dummy_data_factory( :data:`InputProcessor` is not applied to the dummy data. """ # Avoid circular import - from vllm.sequence import SequenceTokenData + from vllm.sequence import SequenceData - dummy_seq_data = SequenceTokenData.from_seqs([0] * seq_len) + dummy_seq_data = SequenceData.from_seqs([0] * seq_len) dummy_multi_modal_data = None return dummy_seq_data, dummy_multi_modal_data @@ -161,7 +161,7 @@ def dummy_data_for_profiling( model_config: "ModelConfig", seq_len: int, mm_registry: "MultiModalRegistry", - ) -> Tuple["SequenceTokenData", Optional["MultiModalDataDict"]]: + ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ Create dummy data for profiling the memory usage of a model. diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 168ec822858cd..8294ac044e8d6 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import SequenceTokenData +from vllm.sequence import SequenceData try: from xformers import ops as xops @@ -61,7 +61,7 @@ def dummy_seq_data_for_blip( else: image_feature_size = image_feature_size_override - return SequenceTokenData.from_counts({ + return SequenceData.from_counts({ image_token_id: image_feature_size, 0: (seq_len - image_feature_size), }) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 7ebf2c529c434..62fc1676bca7e 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -18,7 +18,7 @@ from vllm.model_executor.models.opt import OPTModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors, SequenceTokenData +from vllm.sequence import IntermediateTensors, SequenceData from .blip import (BlipVisionModel, dummy_image_for_blip, get_max_blip_image_tokens) @@ -428,7 +428,7 @@ def dummy_seq_data_for_blip2( else: image_feature_size = image_feature_size_override - return SequenceTokenData.from_counts({ + return SequenceData.from_counts({ image_token_id: image_feature_size * num_images, 0: diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index ef7a326fcac91..271a27bb8134d 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -32,7 +32,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import IntermediateTensors, SequenceTokenData +from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import print_warning_once from .interfaces import SupportsMultiModal @@ -71,7 +71,7 @@ def dummy_seq_data_for_chameleon( else: image_feature_size = image_feature_size_override - return SequenceTokenData.from_counts({ + return SequenceData.from_counts({ image_token_id: image_feature_size * num_images, 0: diff --git 
a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 8ea02ea6a84ac..e69f19eb45253 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -19,7 +19,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import SequenceTokenData +from vllm.sequence import SequenceData try: from xformers import ops as xops @@ -61,7 +61,7 @@ def dummy_seq_data_for_clip( else: image_feature_size = image_feature_size_override - return SequenceTokenData.from_counts({ + return SequenceData.from_counts({ image_token_id: image_feature_size * num_images, 0: diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 50f80a76a8b0f..d04b5b44bc9d3 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -41,7 +41,7 @@ from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceTokenData) + SequenceData) from .interfaces import SupportsMultiModal from .utils import merge_multimodal_embeddings @@ -107,7 +107,7 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int): token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - image_feature_size * num_images) - return SequenceTokenData.from_seqs(token_ids) + return SequenceData.from_seqs(token_ids) def dummy_image_for_fuyu( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 8b79636101a88..8a861b4ef9aaf 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -56,7 +56,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, SequenceTokenData +from vllm.sequence import IntermediateTensors, SequenceData from .idefics2_vision_model import Idefics2VisionTransformer @@ -258,7 +258,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext): def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): - return SequenceTokenData.from_counts({ + return SequenceData.from_counts({ 0: seq_len, }) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index b236e36f62c09..98fe6492343be 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -23,7 +23,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, SequenceTokenData +from vllm.sequence import IntermediateTensors, SequenceData from .interfaces import SupportsMultiModal from .utils import init_vllm_registered_model @@ -61,7 +61,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, image_feature_size = (size**2) // (patch_size**2) num_image_tokens = image_feature_size * num_images - seq_data = SequenceTokenData.from_counts({ + seq_data = SequenceData.from_counts({ image_token_id: num_image_tokens, 0: seq_len - num_image_tokens, }) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 
d284ee2862547..a14f6b5d668dd 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -45,7 +45,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, SequenceTokenData +from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import is_list_of from .utils import flatten_bn, is_pp_missing_parameter, make_layers @@ -801,7 +801,7 @@ def dummy_data_for_qwen( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], -) -> Tuple[SequenceTokenData, Optional[Dict]]: +) -> Tuple[SequenceData, Optional[Dict]]: """Build dummy data for warming up Qwen models; this will only contain text matching the defaults for VLLM unless the model has a visual config. @@ -818,7 +818,7 @@ def dummy_data_for_qwen( # The presence of a visual config indicates this is a multimodal model. # If we don't have it, the model is considered an LLM for warmup purposes. if not hasattr(hf_config, "visual"): - seq_data = SequenceTokenData.from_counts({0: seq_len}) + seq_data = SequenceData.from_counts({0: seq_len}) mm_data = None return seq_data, mm_data @@ -845,7 +845,7 @@ def dummy_data_for_qwen( if len(toks) < seq_len: toks += [0] * (seq_len - len(toks)) - seq_data = SequenceTokenData.from_seqs(toks) + seq_data = SequenceData.from_seqs(toks) # Build the input images; width/height doesn't actually matter here since # the data will get resized and the # of tokens per image is constant diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 0d97a5958164c..6a6d8e7a4cff4 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -66,7 +66,7 @@ from vllm.multimodal.base import MultiModalData from vllm.multimodal.image import cached_get_image_processor from vllm.platforms import current_platform -from vllm.sequence import IntermediateTensors, SequenceTokenData +from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.processor import get_processor logger = init_logger(__name__) @@ -655,7 +655,7 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int: def dummy_data_for_qwen2_vl( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] -) -> Tuple[SequenceTokenData, Optional[MultiModalDataDict]]: +) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: image_processor = cached_get_image_processor(ctx.model_config.model) num_images = mm_counts["image"] @@ -681,7 +681,7 @@ def dummy_data_for_qwen2_vl( hf_config = ctx.get_hf_config(Qwen2VLConfig) - dummy_seqdata = SequenceTokenData.from_counts({ + dummy_seqdata = SequenceData.from_counts({ hf_config.vision_start_token_id: 1, hf_config.image_token_id: diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 1c89682cd76b5..1d56f338d5da1 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import SequenceTokenData +from vllm.sequence import SequenceData try: from xformers import ops as xops @@ -66,7 +66,7 @@ def dummy_seq_data_for_siglip( else: image_feature_size = image_feature_size_override - return SequenceTokenData.from_counts({ + return 
SequenceData.from_counts({ image_token_id: image_feature_size, 0: (seq_len - image_feature_size), }) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 81e047514f1ea..f1e4b9e3cabee 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -37,7 +37,7 @@ from vllm.multimodal.base import MultiModalInputs, NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceTokenData +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.transformers_utils.configs.ultravox import UltravoxConfig _AUDIO_PLACEHOLDER_TOKEN = 128002 @@ -96,7 +96,7 @@ def dummy_data_for_ultravox( other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - len(audio_token_ids)) - seq_data = SequenceTokenData.from_seqs(audio_token_ids + other_token_ids) + seq_data = SequenceData.from_seqs(audio_token_ids + other_token_ids) audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1) mm_dict = {"audio": [audio_and_sr] * audio_count} diff --git a/vllm/sequence.py b/vllm/sequence.py index c80fb4840cd51..be20744cd497d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -6,8 +6,7 @@ from collections import defaultdict from dataclasses import dataclass from functools import cached_property, reduce -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, - Optional) +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional from typing import Sequence as GenericSequence from typing import Set, Tuple, Union, cast @@ -135,50 +134,111 @@ class SequenceDataDelta( new_stage: SequenceStage -class SequenceDataMixin(msgspec.Struct, - omit_defaults=True): # type: ignore[call-arg] +class SequenceData(msgspec.Struct, + omit_defaults=True): # type: ignore[call-arg] """Data associated with a sequence. - Args: + prompt_token_ids: The token IDs of the prompt. + prompt_embeds: The embeddings of the prompt. output_token_ids: The token IDs of the output. Set to an empty list if None. - Attributes: + prompt_token_ids: The token IDs of the prompt. + prompt_embeds: The embeddings of the prompt. output_token_ids: The token IDs of the output. cumulative_logprob: The cumulative log probability of the output. """ + # NOTE: we cannot use Union[List, array] because msgspec cannot support + # union of 2 list types. + _prompt_token_ids: array + _prompt_embeds: Optional[List[torch.Tensor]] = None _output_token_ids: array = msgspec.field( default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) ### The below fields should not be passed as an argument ### _cumulative_logprob: float = 0.0 - + _prompt_token_ids_tuple: Tuple[int, + ...] = msgspec.field(default_factory=tuple) # The number of tokens that are computed (that run against the model). _num_computed_tokens: int = 0 _stage: SequenceStage = SequenceStage.PREFILL _cached_all_token_ids: List[int] = msgspec.field(default_factory=list) - # It is used to get delta input. It is reset when `get_delta_and_reset` # is called. _new_appended_tokens: List[int] = msgspec.field(default_factory=list) - # It is used to compute mrope_position_ids. 
_mrope_position_delta: Optional[int] = None + @staticmethod + def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData": + if len(counts_by_token) == 0: + return SequenceData.from_seqs([]) + + arrs = [ + array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count + for token_id, count in counts_by_token.items() + ] + + return SequenceData(reduce(lambda a, b: a + b, arrs)) + + @staticmethod + def from_seqs( + prompt_token_ids: GenericSequence[int], + output_token_ids: Optional[GenericSequence[int]] = None, + ) -> "SequenceData": + prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + prompt_token_ids) + + if output_token_ids is None: + return SequenceData(prompt_token_ids_arr) + + output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + output_token_ids) + + return SequenceData(prompt_token_ids_arr, + _output_token_ids=output_token_ids_arr) + def __post_init__(self) -> None: + assert self._prompt_token_ids.typecode == "l" assert self._output_token_ids.typecode == "l" - + self._prompt_token_ids_tuple: Tuple[int, ...] = tuple( + self._prompt_token_ids) self._update_cached_all_tokens() - def _update_cached_all_tokens(self) -> None: + def _update_cached_all_tokens(self): + assert isinstance(self._prompt_token_ids, array) assert isinstance(self._output_token_ids, array) - - self._cached_all_token_ids = list(self._output_token_ids) + self._cached_all_token_ids: List[int] = list(self._prompt_token_ids + + self._output_token_ids) @property def cumulative_logprob(self) -> float: return self._cumulative_logprob + @property + def prompt_token_ids(self) -> Tuple[int, ...]: + return self._prompt_token_ids_tuple + + @prompt_token_ids.setter + def prompt_token_ids(self, new_prompt_token_ids) -> None: + raise NotImplementedError + + @property + def prompt_embeds(self) -> Optional[torch.Tensor]: + return self._prompt_embeds + + @prompt_embeds.setter + def prompt_embeds(self, new_prompt_embeds: Optional[torch.Tensor]) -> None: + self._prompt_embeds = new_prompt_embeds + + @property + def prompt_token_ids_array(self) -> array: + """Return the prompt token ids in array type. + Note that the array is in "I" type, and it is not compatible + with torch.long (2 bytes vs 4 bytes). So beware of the usage. 
+ """ + return self._prompt_token_ids + @property def output_token_ids(self) -> Tuple[int, ...]: return tuple(self._output_token_ids) @@ -214,11 +274,13 @@ def append_token_id(self, token_id: int, logprob: float) -> None: self._cumulative_logprob += logprob def get_len(self) -> int: - return self.get_prompt_len() + len(self._output_token_ids) + if self._prompt_embeds is None: + return len(self._output_token_ids) + len(self._prompt_token_ids) + else: + return len(self._output_token_ids) + len(self._prompt_embeds) - @abstractmethod def get_prompt_len(self) -> int: - raise NotImplementedError + return len(self._prompt_token_ids) def get_output_len(self) -> int: return len(self._output_token_ids) @@ -226,13 +288,16 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self._cached_all_token_ids - @abstractmethod def get_prefix_token_ids( - self, - num_tokens: int, + self, num_tokens: int ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: """Get prefix tokens, and make the return value hashable""" - raise NotImplementedError + prompt_length = self.get_prompt_len() + if num_tokens > prompt_length: + return (self._prompt_token_ids_tuple, + tuple(self._output_token_ids[:num_tokens - prompt_length])) + else: + return (self._prompt_token_ids_tuple[:num_tokens], None) def get_num_computed_tokens(self) -> int: """Return the number of prefill tokens that are already computed.""" @@ -263,13 +328,13 @@ def get_num_uncomputed_tokens(self) -> int: # prefill for both prompt and output. return self.get_len() - self.get_num_computed_tokens() - @abstractmethod def get_last_token_id(self) -> int: - raise NotImplementedError + if not self._output_token_ids: + return self._prompt_token_ids[-1] + return self._output_token_ids[-1] - @abstractmethod def get_prompt_token_ids(self) -> Tuple[int, ...]: - raise NotImplementedError + return self.prompt_token_ids def get_output_token_ids(self) -> Tuple[int, ...]: return self.output_token_ids @@ -293,197 +358,14 @@ def apply_delta(self, delta: SequenceDataDelta): def stage(self) -> SequenceStage: return self._stage - -class _TokenBase: - # NOTE: we cannot use Union[List, array] because msgspec cannot support - # union of 2 list types. - _prompt_token_ids: array - - -# Note the ordering here so `_prompt_token_ids` is the first __init__ argument -class SequenceTokenData(SequenceDataMixin, _TokenBase, - omit_defaults=True): # type: ignore[call-arg] - """Data associated with a sequence. - - Args: - prompt_token_ids: The token IDs of the prompt. - output_token_ids: The token IDs of the output. Set to an empty list if - None. - - Attributes: - prompt_token_ids: The token IDs of the prompt. - output_token_ids: The token IDs of the output. - cumulative_logprob: The cumulative log probability of the output. - """ - - # For tagged union - type: Literal["tokens"] = "tokens" - - ### The below fields should not be passed as an argument ### - _prompt_token_ids_tuple: Tuple[int, - ...] 
= msgspec.field(default_factory=tuple) - - @staticmethod - def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceTokenData": - if len(counts_by_token) == 0: - return SequenceTokenData.from_seqs([]) - - arrs = [ - array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count - for token_id, count in counts_by_token.items() - ] - - return SequenceTokenData(reduce(lambda a, b: a + b, arrs)) - - @staticmethod - def from_seqs( - prompt_token_ids: GenericSequence[int], - output_token_ids: Optional[GenericSequence[int]] = None, - ) -> "SequenceTokenData": - prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - prompt_token_ids) - - if output_token_ids is None: - return SequenceTokenData(prompt_token_ids_arr) - - output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - output_token_ids) - - return SequenceTokenData(prompt_token_ids_arr, output_token_ids_arr) - - def __post_init__(self) -> None: - assert self._prompt_token_ids.typecode == "l" - assert self._output_token_ids.typecode == "l" - - self._prompt_token_ids_tuple = tuple(self._prompt_token_ids) - - self._update_cached_all_tokens() - - def _update_cached_all_tokens(self) -> None: - assert isinstance(self._prompt_token_ids, array) - assert isinstance(self._output_token_ids, array) - - self._cached_all_token_ids = list(self._prompt_token_ids + - self._output_token_ids) - - @property - def prompt_token_ids(self) -> Tuple[int, ...]: - return self._prompt_token_ids_tuple - - @property - def prompt_token_ids_array(self) -> array: - """Return the prompt token ids in array type. - - Note that the array is in "I" type, and it is not compatible - with torch.long (2 bytes vs 4 bytes). So beware of the usage. - """ - return self._prompt_token_ids - - def get_prompt_len(self) -> int: - return len(self._prompt_token_ids) - - def get_last_token_id(self) -> int: - if not self._output_token_ids: - return self._prompt_token_ids[-1] - - return self._output_token_ids[-1] - - def get_prompt_token_ids(self) -> Tuple[int, ...]: - return self.prompt_token_ids - - def get_prefix_token_ids( - self, - num_tokens: int, - ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: - prompt_length = self.get_prompt_len() - if num_tokens > prompt_length: - return (self._prompt_token_ids_tuple, - tuple(self._output_token_ids[:num_tokens - prompt_length])) - - return (self._prompt_token_ids_tuple[:num_tokens], None) - def __repr__(self) -> str: - return (f"SequenceTokenData(" + return (f"SequenceData(" f"prompt_token_ids={self._prompt_token_ids}, " f"output_token_ids={self.output_token_ids}, " f"cumulative_logprob={self.cumulative_logprob}, " f"get_num_computed_tokens={self.get_num_computed_tokens()}") -class _EmbedBase: - _prompt_embeds: torch.Tensor - - -# Note the ordering here so `_prompt_embeds` is the first __init__ argument -class SequenceEmbedData(SequenceDataMixin, _EmbedBase, - omit_defaults=True): # type: ignore[call-arg] - """Data associated with a sequence. - - Args: - prompt_embeds: The embeddings of the prompt. - output_token_ids: The token IDs of the output. Set to an empty list if - None. - cumulative_logprob: The cumulative log probability of the output. - - Attributes: - prompt_embeds: The embeddings of the prompt. - output_token_ids: The token IDs of the output. - cumulative_logprob: The cumulative log probability of the output. - """ - - # For tagged union - type: Literal["embeds"] = "embeds" - - ### The below fields should not be passed as an argument ### - _dummy_token_ids_tuple: Tuple[int, - ...] 
= msgspec.field(default_factory=tuple) - - def __post_init__(self) -> None: - assert self._output_token_ids.typecode == "l" - - # Dummy value - self._dummy_token_ids_tuple = tuple([0] * self._prompt_embeds) - - self._update_cached_all_tokens() - - @property - def prompt_embeds(self) -> torch.Tensor: - return self._prompt_embeds - - def get_prompt_len(self) -> int: - return len(self._prompt_embeds) - - def get_last_token_id(self) -> int: - if not self._output_token_ids: - return self._dummy_token_ids_tuple[-1] - - return self._output_token_ids[-1] - - def get_prompt_token_ids(self) -> Tuple[int, ...]: - return self._dummy_token_ids_tuple - - def get_prefix_token_ids( - self, - num_tokens: int, - ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: - prompt_length = self.get_prompt_len() - if num_tokens > prompt_length: - return (self._dummy_token_ids_tuple, - tuple(self._output_token_ids[:num_tokens - prompt_length])) - - return (self._dummy_token_ids_tuple[:num_tokens], None) - - def __repr__(self) -> str: - return (f"SequenceEmbedData(" - f"prompt_embeds={self._prompt_embeds}, " - f"output_token_ids={self.output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"get_num_computed_tokens={self.get_num_computed_tokens()}") - - -SequenceData = Union[SequenceTokenData, SequenceEmbedData] - - class Sequence: """Stores the data, status, and block information of a sequence. @@ -520,10 +402,10 @@ def __init__( data: SequenceData if self.prompt_token_ids: - data = SequenceTokenData.from_seqs(self.prompt_token_ids) + data = SequenceData.from_seqs(self.prompt_token_ids) else: assert self.prompt_embeds is not None - data = SequenceEmbedData(self.prompt_embeds) + data = SequenceData(self.prompt_embeds) self.data = data From f451192f0eb60f52112db99559599ceeacc2d0e8 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 16:42:05 +0000 Subject: [PATCH 41/88] Fix init error for embeds --- vllm/sequence.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index be20744cd497d..e897aa4f539b6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -151,7 +151,7 @@ class SequenceData(msgspec.Struct, # NOTE: we cannot use Union[List, array] because msgspec cannot support # union of 2 list types. 
_prompt_token_ids: array - _prompt_embeds: Optional[List[torch.Tensor]] = None + _prompt_embeds: Optional[torch.Tensor] = None _output_token_ids: array = msgspec.field( default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) @@ -185,18 +185,22 @@ def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData": def from_seqs( prompt_token_ids: GenericSequence[int], output_token_ids: Optional[GenericSequence[int]] = None, + *, + prompt_embeds: Optional[torch.Tensor] = None, ) -> "SequenceData": prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids) if output_token_ids is None: - return SequenceData(prompt_token_ids_arr) + return SequenceData(prompt_token_ids_arr, + _prompt_embeds=prompt_embeds) output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, output_token_ids) return SequenceData(prompt_token_ids_arr, - _output_token_ids=output_token_ids_arr) + _output_token_ids=output_token_ids_arr, + _prompt_embeds=prompt_embeds) def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" @@ -405,7 +409,8 @@ def __init__( data = SequenceData.from_seqs(self.prompt_token_ids) else: assert self.prompt_embeds is not None - data = SequenceData(self.prompt_embeds) + data = SequenceData.from_seqs([], + prompt_embeds=self.prompt_embeds) self.data = data From 3d436e92cfefd8645ff2e7ad81b1a8f71841b296 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 16:44:07 +0000 Subject: [PATCH 42/88] Be a bit more efficient --- vllm/sequence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index e897aa4f539b6..3ed651bc1c38e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -179,7 +179,7 @@ def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData": for token_id, count in counts_by_token.items() ] - return SequenceData(reduce(lambda a, b: a + b, arrs)) + return SequenceData(reduce(array.__add__, arrs)) @staticmethod def from_seqs( From dfecf4b2a40d659a9003ba07733f7e66f9b2dbc4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 16:47:31 +0000 Subject: [PATCH 43/88] Rename --- vllm/entrypoints/llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0456d8a0d72a5..ca8d144286eea 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -258,8 +258,8 @@ def generate( @overload def generate( self, - inputs: Union[PromptType, Sequence[PromptType]], - /, # We may enable `inputs` keyword after removing the old API + prompts: Union[PromptType, Sequence[PromptType]], + /, # We may enable `prompts` keyword after removing the old API *, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, From 29b9d5c13665511073282c52e3f90a33fc22b690 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 16:48:55 +0000 Subject: [PATCH 44/88] Rename 2 --- vllm/entrypoints/llm.py | 32 +++++++++++++++++--------------- vllm/sequence.py | 3 +-- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ca8d144286eea..5acc400615c37 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -320,12 +320,13 @@ def generate( "models (XForCausalLM, XForConditionalGeneration).") if prompt_token_ids is not None: - inputs = self._convert_v1_inputs( + parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - inputs = 
cast(Union[PromptType, Sequence[PromptType]], prompts) + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: @@ -340,7 +341,7 @@ def generate( sampling_params = SamplingParams() self._validate_and_add_requests( - inputs=inputs, + prompts=parsed_prompts, params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -494,7 +495,7 @@ def encode( @overload def encode( self, - inputs: Union[PromptType, Sequence[PromptType]], + prompts: Union[PromptType, Sequence[PromptType]], /, # We may enable `inputs` keyword after removing the old API *, pooling_params: Optional[Union[PoolingParams, @@ -553,19 +554,20 @@ def encode( ) if prompt_token_ids is not None: - inputs = self._convert_v1_inputs( + parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - inputs = cast(Union[PromptType, Sequence[PromptType]], prompts) + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() self._validate_and_add_requests( - inputs=inputs, + prompts=parsed_prompts, params=pooling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -609,7 +611,7 @@ def _convert_v1_inputs( raise ValueError("Either prompts or prompt_token_ids must be " "provided.") - inputs: List[PromptType] = [] + parsed_prompts: List[PromptType] = [] for i in range(num_requests): item: PromptType @@ -620,24 +622,24 @@ def _convert_v1_inputs( else: raise AssertionError - inputs.append(item) + parsed_prompts.append(item) - return inputs + return parsed_prompts def _validate_and_add_requests( self, - inputs: Union[PromptType, Sequence[PromptType]], + prompts: Union[PromptType, Sequence[PromptType]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], guided_options: Optional[GuidedDecodingRequest] = None, ) -> None: - if isinstance(inputs, (str, dict)): + if isinstance(prompts, (str, dict)): # Convert a single prompt to a list. - inputs = [inputs] + prompts = [prompts] - num_requests = len(inputs) + num_requests = len(prompts) if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") @@ -654,7 +656,7 @@ def _validate_and_add_requests( sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. 
- for i, request_inputs in enumerate(inputs): + for i, request_inputs in enumerate(prompts): self._add_request( request_inputs, params[i] if isinstance(params, Sequence) else params, diff --git a/vllm/sequence.py b/vllm/sequence.py index 3ed651bc1c38e..6604671ed1e15 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -409,8 +409,7 @@ def __init__( data = SequenceData.from_seqs(self.prompt_token_ids) else: assert self.prompt_embeds is not None - data = SequenceData.from_seqs([], - prompt_embeds=self.prompt_embeds) + data = SequenceData.from_seqs([], prompt_embeds=self.prompt_embeds) self.data = data From d5ec13a40c4af13864a106ee44fd3804389769a3 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 16:53:50 +0000 Subject: [PATCH 45/88] Fix encoder-decoder test prompts --- tests/core/utils.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index 40d8f51fc186e..d7459b90ab5f9 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -4,6 +4,7 @@ from typing import Tuple from vllm import SamplingParams +from vllm.inputs import EncoderDecoderInputs from vllm.lora.request import LoRARequest from vllm.sequence import Logprob, Sequence, SequenceGroup @@ -62,23 +63,27 @@ def create_dummy_prompt_encoder_decoder( encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) - inputs = { - "prompt": decoder_prompt_str, - "prompt_token_ids": decoder_prompt_tokens, - "encoder_prompt": encoder_prompt_str, - "encoder_prompt_token_ids": encoder_prompt_tokens, - "multi_modal_data": None, + inputs: EncoderDecoderInputs = { + "decoder": { + "type": "token", + "prompt": decoder_prompt_str, + "prompt_token_ids": decoder_prompt_tokens, + }, + "encoder": { + "type": "token", + "prompt": encoder_prompt_str, + "prompt_token_ids": encoder_prompt_tokens, + }, } decoder_prompt = Sequence(int(request_id), inputs=inputs, - block_size=block_size, - from_decoder_prompt=True) + block_size=block_size) encoder_prompt = Sequence(int(request_id), inputs=inputs, - block_size=block_size, - from_decoder_prompt=False) + block_size=block_size) + seq_group = SequenceGroup(request_id=request_id, seqs=[decoder_prompt], sampling_params=SamplingParams( From b5632f9b5f972720cfe09aa07fa6ae821e7b8080 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 17:01:57 +0000 Subject: [PATCH 46/88] Fix validation for encoder-decoder models --- vllm/engine/llm_engine.py | 8 ++++---- vllm/inputs/parse.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2fd2d2b012b1b..3989fbd1be2b3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1707,12 +1707,12 @@ def _support_prompt_embeds(self) -> Tuple[bool, str]: def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]): - if self.is_encoder_decoder_model(): - prompt_ids = inputs.get("encoder_prompt_token_ids") + if is_valid_encoder_decoder_inputs(inputs): + prompt_ids = inputs["encoder"].get("prompt_token_ids") + prompt_embeds = inputs["encoder"].get("prompt_embeds") else: prompt_ids = inputs.get("prompt_token_ids") - - prompt_embeds = inputs.get("prompt_embeds") + prompt_embeds = inputs.get("prompt_embeds") if prompt_ids is None: if prompt_embeds is None: diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 6c5be9007bced..3d36c60a42478 100644 --- 
a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -112,4 +112,4 @@ def is_explicit_encoder_decoder_prompt( def is_valid_encoder_decoder_inputs( inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], ) -> TypeIs[EncoderDecoderInputs]: - return "encoder_prompt_token_ids" in inputs + return "encoder" in inputs and "decoder" in inputs From 3da5ad6ddaceac6f5ebcec6f78fddde0950e6a17 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 17:06:24 +0000 Subject: [PATCH 47/88] Fix naming --- vllm/engine/multiprocessing/client.py | 12 +++++------- vllm/inputs/preprocess.py | 26 +++++++++++++------------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 7fbd4ccddfe8b..71099115ea125 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -389,8 +389,7 @@ def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -399,13 +398,13 @@ def generate( prompt_adapter_request: Prompt Adapter request to use for generation, if any. """ - return self._process_request(inputs, sampling_params, request_id, + return self._process_request(prompt, sampling_params, request_id, lora_request, trace_headers, prompt_adapter_request) def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -418,8 +417,7 @@ def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -430,7 +428,7 @@ def encode( The output `EmbeddingRequestOutput` objects from the LLMEngine for the request. 
""" - return self._process_request(inputs, pooling_params, request_id, + return self._process_request(prompt, pooling_params, request_id, lora_request, trace_headers) async def _process_request( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 2c80c5b13002a..92a25199821fd 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -427,20 +427,20 @@ def _process_encoder_decoder_prompt( async def _process_encoder_decoder_prompt_async( self, - inputs: PromptType, + prompt: PromptType, request_id: str, ) -> EncoderDecoderInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_inputs: DecoderOnlyInputs decoder_inputs: Union[EmptyInputs, DecoderOnlyInputs] - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): encoder_task = self._prompt_to_llm_inputs_async( - inputs["encoder_prompt"], + prompt["encoder_prompt"], request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: + if (decoder_input := prompt["decoder_prompt"]) is None: encoder_inputs = await encoder_task decoder_inputs = empty_inputs() else: @@ -453,7 +453,7 @@ async def _process_encoder_decoder_prompt_async( encoder_task, decoder_task) else: encoder_inputs = await self._prompt_to_llm_inputs_async( - inputs, + prompt, request_id=request_id, ) @@ -534,7 +534,7 @@ async def _process_decoder_only_prompt_async( def preprocess( self, - inputs: PromptType, + prompt: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -544,17 +544,17 @@ def preprocess( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return self._process_encoder_decoder_prompt( - inputs, + prompt, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return self._process_decoder_only_prompt( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -562,7 +562,7 @@ def preprocess( async def preprocess_async( self, - inputs: PromptType, + prompt: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -572,17 +572,17 @@ async def preprocess_async( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return await self._process_encoder_decoder_prompt_async( - inputs, + prompt, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return await self._process_decoder_only_prompt_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, From 94ace38b340c7d77247ed1f1e207f0f05a16d949 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 17:22:53 +0000 Subject: [PATCH 48/88] Rename `PromptInputs` to `PromptType`, and `inputs` to `prompt` --- benchmarks/benchmark_latency.py | 8 +- .../dev/multimodal/multimodal_index.rst | 2 +- .../dev/offline_inference/llm_inputs.rst | 2 +- docs/source/models/vlm.rst | 2 +- tests/mq_llm_engine/test_error_handling.py | 12 +-- tests/mq_llm_engine/utils.py | 2 +- vllm/__init__.py | 4 +- 
vllm/engine/async_llm_engine.py | 24 +++--- vllm/engine/llm_engine.py | 9 +- vllm/engine/multiprocessing/__init__.py | 4 +- vllm/engine/multiprocessing/client.py | 20 ++--- vllm/engine/multiprocessing/engine.py | 2 +- vllm/engine/protocol.py | 8 +- vllm/entrypoints/llm.py | 76 ++++++++-------- vllm/inputs/__init__.py | 6 +- vllm/inputs/data.py | 26 +++--- vllm/inputs/parse.py | 22 ++--- vllm/inputs/preprocess.py | 86 +++++++++---------- 18 files changed, 155 insertions(+), 160 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index a39d1cf842f06..eadf994cacd34 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -61,7 +61,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptInputs] = [{ + dummy_prompts: List[PromptType] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] @@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index 241b2ccd0991e..e112b43aade5e 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -8,7 +8,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` -via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. +via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities by following :ref:`this guide `. diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst index 9adf82d43f3e0..0d47281db485e 100644 --- a/docs/source/dev/offline_inference/llm_inputs.rst +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -1,7 +1,7 @@ LLM Inputs ========== -.. autodata:: vllm.inputs.PromptInputs +.. autodata:: vllm.inputs.PromptType .. autoclass:: vllm.inputs.TextPrompt :show-inheritance: diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 08db891665044..ca5b125369c85 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. 
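A quick sketch of the renamed benchmark pattern above, assuming an already-constructed ``LLM`` instance; the model name and token ids are placeholders:

from typing import List

from vllm import LLM, SamplingParams
from vllm.inputs import PromptType

llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration
dummy_prompts: List[PromptType] = [{"prompt_token_ids": [0] * 32}
                                   for _ in range(8)]
llm.generate(dummy_prompts,
             sampling_params=SamplingParams(max_tokens=1),
             use_tqdm=False)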
-To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`: * ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 49cfc5aa04c36..7c466c92d5293 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket): # Throws an error in first forward pass. with pytest.raises(RAISED_ERROR): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket): # Engine is errored, should get ENGINE_DEAD_ERROR. with pytest.raises(MQEngineDeadError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket): # Generate call should throw ENGINE_DEAD_ERROR with pytest.raises(MQEngineDeadError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -165,7 +165,7 @@ async def bad_abort_after_2s(): # with reference to the original KeyError("foo") with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( - inputs="Hello my name is", + prompt="Hello my name is", sampling_params=SamplingParams(max_tokens=2000), request_id=uuid.uuid4()): pass @@ -190,7 +190,7 @@ async def test_bad_request(tmp_socket): # Invalid request should fail, but not crash the server. with pytest.raises(ValueError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-1", lora_request=LoRARequest( @@ -199,7 +199,7 @@ async def test_bad_request(tmp_socket): pass # This request should be okay. 
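# Sketch of the same call pattern outside the test harness; `client` is assumed
# to be a healthy MQLLMEngineClient and the request id is a placeholder.
from vllm import SamplingParams

async def complete(client, text: str) -> str:
    final = None
    async for out in client.generate(prompt=text,
                                     sampling_params=SamplingParams(max_tokens=32),
                                     request_id="sketch-1"):
        final = out
    return final.outputs[0].text if final is not None else ""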
- async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-2"): pass diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index e27fd77923412..3ffa126070ca0 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -20,7 +20,7 @@ async def generate( count = 0 async for out in client.generate( request_id=request_id, - inputs="Hello my name is Robert and", + prompt="Hello my name is Robert and", sampling_params=SamplingParams(max_tokens=num_tokens, temperature=0)): diff --git a/vllm/__init__.py b/vllm/__init__.py index 0895c571d1d89..59af68fb493e5 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,7 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -19,7 +19,7 @@ "__version__", "LLM", "ModelRegistry", - "PromptInputs", + "PromptType", "TextPrompt", "TokensPrompt", "SamplingParams", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 34e7e05341f02..f108751056ab5 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -17,7 +17,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -405,7 +405,7 @@ async def stop_remote_worker_execution_loop_async(self) -> None: async def add_request_async( self, request_id: str, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -420,7 +420,7 @@ async def add_request_async( arrival_time = time.time() preprocessed_inputs = await self.input_preprocessor.preprocess_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -777,7 +777,7 @@ async def run_engine_loop(engine_ref: ReferenceType): async def add_request( self, request_id: str, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -797,7 +797,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, verbose=self.log_requests, - inputs=inputs, + prompt=prompt, params=params, arrival_time=arrival_time or time.time(), lora_request=lora_request, @@ -808,7 +808,7 @@ async def add_request( async def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -822,8 +822,7 @@ async def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. 
sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -881,7 +880,7 @@ async def generate( """ async for output in await self.add_request( request_id, - inputs, + prompt, sampling_params, lora_request=lora_request, trace_headers=trace_headers, @@ -891,7 +890,7 @@ async def generate( async def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -904,8 +903,7 @@ async def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -959,7 +957,7 @@ async def encode( """ async for output in await self.add_request( request_id, - inputs, + prompt, pooling_params, lora_request=lora_request, trace_headers=trace_headers, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2743d5c7d2282..39409757d3812 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -29,7 +29,7 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptInputs) + InputRegistry, LLMInputs, PromptType) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -680,7 +680,7 @@ def stop_remote_worker_execution_loop(self) -> None: def add_request( self, request_id: str, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -695,8 +695,7 @@ def add_request( Args: request_id: The unique ID of the request. - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. params: Parameters for sampling or pooling. :class:`~vllm.SamplingParams` for text generation. 
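# Sketch of a direct call with the renamed parameter; `engine` is assumed to be
# a constructed LLMEngine and the request id is a placeholder.
from vllm import SamplingParams

engine.add_request(
    request_id="0",
    prompt="Hello, my name is",  # any PromptType value is accepted here
    params=SamplingParams(max_tokens=16),
)
while engine.has_unfinished_requests():
    request_outputs = engine.step()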
@@ -736,7 +735,7 @@ def add_request( arrival_time = time.time() preprocessed_inputs = self.input_preprocessor.preprocess( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 700332864d17a..09aa279f1e22c 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -3,7 +3,7 @@ from typing import List, Mapping, Optional, Union from vllm import PoolingParams -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest @@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError): @dataclass class RPCProcessRequest: - inputs: PromptInputs + prompt: PromptType params: Union[SamplingParams, PoolingParams] request_id: str lora_request: Optional[LoRARequest] = None diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index aa9dbbd448af2..71099115ea125 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -25,7 +25,7 @@ RPCStartupResponse) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -375,7 +375,7 @@ def dead_error(self) -> BaseException: def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -389,8 +389,7 @@ def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -399,13 +398,13 @@ def generate( prompt_adapter_request: Prompt Adapter request to use for generation, if any. """ - return self._process_request(inputs, sampling_params, request_id, + return self._process_request(prompt, sampling_params, request_id, lora_request, trace_headers, prompt_adapter_request) def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -418,8 +417,7 @@ def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -430,12 +428,12 @@ def encode( The output `EmbeddingRequestOutput` objects from the LLMEngine for the request. 
""" - return self._process_request(inputs, pooling_params, request_id, + return self._process_request(prompt, pooling_params, request_id, lora_request, trace_headers) async def _process_request( self, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], request_id: str, lora_request: Optional[LoRARequest] = None, @@ -468,7 +466,7 @@ async def _process_request( request_bytes = pickle.dumps( RPCProcessRequest( - inputs=inputs, + prompt=prompt, params=params, request_id=request_id, lora_request=lora_request, diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index f4ca231570853..788c1573ae255 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -245,7 +245,7 @@ def _handle_process_request(self, request: RPCProcessRequest): try: self.engine.add_request( request_id=request_id, - inputs=request.inputs, + prompt=request.prompt, params=request.params, lora_request=request.lora_request, trace_headers=request.trace_headers, diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 70444faa670a2..d0bbeb357b506 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,7 @@ from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.inputs.data import PromptInputs +from vllm.inputs.data import PromptType from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -35,19 +35,19 @@ def dead_error(self) -> BaseException: def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: - """Generates outputs for a request""" + """Generate outputs for a request.""" ... def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 248b070611cd2..5a98ef11ab94d 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -10,7 +10,7 @@ apply_hf_chat_template, apply_mistral_chat_template, parse_chat_messages) -from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -258,7 +258,7 @@ def generate( @overload def generate( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[PromptType, Sequence[PromptType]], /, # We may enable `inputs` keyword after removing the old API *, sampling_params: Optional[Union[SamplingParams, @@ -276,7 +276,7 @@ def generate( ) def generate( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[Union[PromptType, Sequence[PromptType]], Optional[Union[str, List[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -294,7 +294,9 @@ def generate( into a single list and pass it to this method. Args: - inputs: A list of inputs to generate completions for. + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. 
See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. @@ -320,12 +322,13 @@ def generate( "models (XForCausalLM, XForConditionalGeneration).") if prompt_token_ids is not None: - inputs = self._convert_v1_inputs( + parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: @@ -340,7 +343,7 @@ def generate( sampling_params = SamplingParams() self._validate_and_add_requests( - inputs=inputs, + prompts=parsed_prompts, params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -396,9 +399,9 @@ def chat( conversation, mm_data = parse_chat_messages(messages, model_config, tokenizer) - prompt: Union[str, List[int]] + prompt_data: Union[str, List[int]] if isinstance(tokenizer, MistralTokenizer): - prompt = apply_mistral_chat_template( + prompt_data = apply_mistral_chat_template( tokenizer, messages=messages, chat_template=chat_template, @@ -406,7 +409,7 @@ def chat( tools=tools, ) else: - prompt = apply_hf_chat_template( + prompt_data = apply_hf_chat_template( tokenizer, conversation=conversation, chat_template=chat_template, @@ -414,17 +417,17 @@ def chat( tools=tools, ) - inputs: PromptInputs - if is_list_of(prompt, int): - inputs = TokensPrompt(prompt_token_ids=prompt) + prompt: PromptType + if is_list_of(prompt_data, int): + prompt = TokensPrompt(prompt_token_ids=prompt_data) else: - inputs = TextPrompt(prompt=prompt) + prompt = TextPrompt(prompt=prompt_data) if mm_data is not None: - inputs["multi_modal_data"] = mm_data + prompt["multi_modal_data"] = mm_data return self.generate( - inputs, + prompt, sampling_params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, @@ -494,7 +497,7 @@ def encode( @overload def encode( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[PromptType, Sequence[PromptType]], /, # We may enable `inputs` keyword after removing the old API *, pooling_params: Optional[Union[PoolingParams, @@ -512,7 +515,7 @@ def encode( ) def encode( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[Union[PromptType, Sequence[PromptType]], Optional[Union[str, List[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -528,9 +531,9 @@ def encode( into a single list and pass it to this method. Args: - inputs: The inputs to the LLM. You may pass a sequence of inputs for - batch inference. See :class:`~vllm.inputs.PromptInputs` - for more details about the format of each input. + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. 
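# Sketch of the singleton prompt forms the `prompts` argument accepts after the
# rename; the token ids and image object are placeholders.
from vllm import TextPrompt, TokensPrompt

text_prompt = TextPrompt(prompt="Hello, my name is")
tokens_prompt = TokensPrompt(prompt_token_ids=[1, 2, 3])
# Multi-modal data rides along on either dict form, as in chat() above:
text_prompt["multi_modal_data"] = {"image": pil_image}  # `pil_image` is assumed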
@@ -553,19 +556,20 @@ def encode( ) if prompt_token_ids is not None: - inputs = self._convert_v1_inputs( + parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() self._validate_and_add_requests( - inputs=inputs, + prompts=parsed_prompts, params=pooling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -609,9 +613,9 @@ def _convert_v1_inputs( raise ValueError("Either prompts or prompt_token_ids must be " "provided.") - inputs: List[PromptInputs] = [] + parsed_prompts: List[PromptType] = [] for i in range(num_requests): - item: PromptInputs + item: PromptType if prompts is not None: item = TextPrompt(prompt=prompts[i]) @@ -620,24 +624,24 @@ def _convert_v1_inputs( else: raise AssertionError - inputs.append(item) + parsed_prompts.append(item) - return inputs + return parsed_prompts def _validate_and_add_requests( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[PromptType, Sequence[PromptType]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], guided_options: Optional[GuidedDecodingRequest] = None, ) -> None: - if isinstance(inputs, (str, dict)): + if isinstance(prompts, (str, dict)): # Convert a single prompt to a list. - inputs = [inputs] + prompts = [prompts] - num_requests = len(inputs) + num_requests = len(prompts) if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") @@ -654,9 +658,9 @@ def _validate_and_add_requests( sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. 
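# Sketch of the batch form this loop handles: one SamplingParams per prompt,
# with the matching-length check enforced above. `llm` is an assumed LLM.
from vllm import SamplingParams

outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    sampling_params=[
        SamplingParams(temperature=0.0, max_tokens=8),
        SamplingParams(temperature=0.8, max_tokens=32),
    ],
)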
- for i, request_inputs in enumerate(inputs): + for i, prompt in enumerate(prompts): self._add_request( - request_inputs, + prompt, params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, @@ -665,7 +669,7 @@ def _validate_and_add_requests( def _add_request( self, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -673,7 +677,7 @@ def _add_request( request_id = str(next(self.request_counter)) self.llm_engine.add_request( request_id, - inputs, + prompt, params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 0b08e9691f915..ba1bef1ab3ecc 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,5 +1,5 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + LLMInputs, PromptType, SingletonPrompt, TextPrompt, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry @@ -16,8 +16,8 @@ __all__ = [ "TextPrompt", "TokensPrompt", - "PromptInputs", - "SingletonPromptInputs", + "PromptType", + "SingletonPrompt", "ExplicitEncoderDecoderPrompt", "LLMInputs", "EncoderDecoderLLMInputs", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 75ab0c770155b..e072bb65714b9 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -33,7 +33,7 @@ class TokensPrompt(TypedDict): """ -SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt] +SingletonPrompt = Union[str, TextPrompt, TokensPrompt] """ Set of possible schemas for a single LLM input: @@ -46,7 +46,7 @@ class TokensPrompt(TypedDict): the user desires to express both the encoder & decoder prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` -A prompt of type :class:`SingletonPromptInputs` may be employed +A prompt of type :class:`SingletonPromptType` may be employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or @@ -55,12 +55,12 @@ class TokensPrompt(TypedDict): """ _T1_co = TypeVar("_T1_co", - bound=SingletonPromptInputs, - default=SingletonPromptInputs, + bound=SingletonPrompt, + default=SingletonPrompt, covariant=True) _T2_co = TypeVar("_T2_co", - bound=SingletonPromptInputs, - default=SingletonPromptInputs, + bound=SingletonPrompt, + default=SingletonPrompt, covariant=True) @@ -72,7 +72,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): The encoder and decoder prompts, respectively, may formatted according to any of the - :class:`SingletonPromptInputs` schemas, and are not + :class:`SingletonPromptType` schemas, and are not required to have the same schema. Only the encoder prompt may have multi-modal data. @@ -81,7 +81,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): be used as an input to a decoder-only model, and that the `encoder_prompt` and `decoder_prompt` fields of this data structure themselves must be - :class:`SingletonPromptInputs` instances. + :class:`SingletonPromptType` instances. 
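# Sketch of an explicit encoder/decoder prompt under the renamed types; each
# half may use any SingletonPrompt form, and a None decoder prompt lets the
# preprocessor supply the decoder start tokens.
from vllm import TextPrompt
from vllm.inputs import ExplicitEncoderDecoderPrompt

enc_dec_prompt: ExplicitEncoderDecoderPrompt = {
    "encoder_prompt": TextPrompt(prompt="An article to summarize ..."),
    "decoder_prompt": None,
}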
""" encoder_prompt: _T1_co @@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): decoder_prompt: Optional[_T2_co] -PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] +PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] """ Set of possible schemas for an LLM input, including both decoder-only and encoder/decoder input types: @@ -140,12 +140,8 @@ class EncoderDecoderLLMInputs(LLMInputs): """ -_T1 = TypeVar("_T1", - bound=SingletonPromptInputs, - default=SingletonPromptInputs) -_T2 = TypeVar("_T2", - bound=SingletonPromptInputs, - default=SingletonPromptInputs) +_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) +_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) def build_explicit_enc_dec_prompt( diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index ac9d355c64c80..e5fa1e4184277 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -5,7 +5,7 @@ from vllm.utils import is_list_of from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + LLMInputs, PromptType, SingletonPrompt, TextPrompt, TokensPrompt) @@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict): def parse_singleton_prompt( - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: - if isinstance(inputs, str): - return ParsedStrPrompt(type="str", content=inputs) - elif isinstance(inputs, dict): - if "prompt_token_ids" in inputs: + if isinstance(prompt, str): + return ParsedStrPrompt(type="str", content=prompt) + elif isinstance(prompt, dict): + if "prompt_token_ids" in prompt: return ParsedTokensPrompt(type="tokens", - content=inputs) # type: ignore - elif "prompt" in inputs: - return ParsedTextPrompt(type="text", content=inputs) + content=prompt) # type: ignore + elif "prompt" in prompt: + return ParsedTextPrompt(type="text", content=prompt) raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") def is_explicit_encoder_decoder_prompt( - inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]: - return isinstance(inputs, dict) and "encoder_prompt" in inputs + prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(prompt, dict) and "encoder_prompt" in prompt def is_valid_encoder_decoder_llm_inputs( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index be2aa5f8cb7d0..1f1b048d37e9b 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -9,8 +9,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, - SingletonPromptInputs) +from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType, + SingletonPrompt) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt if TYPE_CHECKING: @@ -206,7 +206,7 @@ async def _tokenize_prompt_async( def _extract_prompt_components( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: @@ -216,7 +216,7 @@ def _extract_prompt_components( Arguments: * request_id - * inputs: single encoder or decoder input prompt + * prompt: single encoder or decoder input prompt * lora_request: this is only valid for decoder prompts Returns: @@ -226,24 +226,24 @@ def _extract_prompt_components( * 
multi_modal_data ''' - parsed = parse_singleton_prompt(inputs) + parsed = parse_singleton_prompt(prompt) if parsed["type"] == "str": - prompt = parsed["content"] + prompt_text = parsed["content"] prompt_token_ids = self._tokenize_prompt( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt = None + prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt = parsed["content"]["prompt"] + prompt_text = parsed["content"]["prompt"] prompt_token_ids = self._tokenize_prompt( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) @@ -251,33 +251,33 @@ def _extract_prompt_components( else: assert_never(parsed) - return prompt, prompt_token_ids, multi_modal_data + return prompt_text, prompt_token_ids, multi_modal_data async def _extract_prompt_components_async( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: """Async version of :meth:`_extract_prompt_components`.""" - parsed = parse_singleton_prompt(inputs) + parsed = parse_singleton_prompt(prompt) if parsed["type"] == "str": - prompt = parsed["content"] + prompt_text = parsed["content"] prompt_token_ids = await self._tokenize_prompt_async( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt = None + prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt = parsed["content"]["prompt"] + prompt_text = parsed["content"]["prompt"] prompt_token_ids = await self._tokenize_prompt_async( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) @@ -285,7 +285,7 @@ async def _extract_prompt_components_async( else: assert_never(parsed) - return prompt, prompt_token_ids, multi_modal_data + return prompt_text, prompt_token_ids, multi_modal_data def _build_enc_dec_llm_inputs( self, @@ -311,7 +311,7 @@ def _build_enc_dec_llm_inputs( def _process_encoder_decoder_prompt( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, ) -> EncoderDecoderLLMInputs: ''' @@ -339,7 +339,7 @@ def _process_encoder_decoder_prompt( Arguments: - * inputs: an input prompt + * prompt: an input prompt * request_id Returns: @@ -350,13 +350,13 @@ def _process_encoder_decoder_prompt( encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): encoder_comps = self._extract_prompt_components( - inputs["encoder_prompt"], + prompt["encoder_prompt"], request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: + if (decoder_input := prompt["decoder_prompt"]) is None: decoder_comps = None, None, None else: decoder_comps = self._extract_prompt_components( @@ -365,7 +365,7 @@ def _process_encoder_decoder_prompt( ) else: encoder_comps = self._extract_prompt_components( - inputs, + prompt, request_id=request_id, ) @@ -375,20 +375,20 @@ def _process_encoder_decoder_prompt( async def _process_encoder_decoder_prompt_async( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, ) -> EncoderDecoderLLMInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_comps: 
PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): encoder_task = self._extract_prompt_components_async( - inputs["encoder_prompt"], + prompt["encoder_prompt"], request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: + if (decoder_input := prompt["decoder_prompt"]) is None: encoder_comps = await encoder_task decoder_comps = None, None, None else: @@ -401,7 +401,7 @@ async def _process_encoder_decoder_prompt_async( encoder_task, decoder_task) else: encoder_comps = await self._extract_prompt_components_async( - inputs, + prompt, request_id=request_id, ) @@ -425,7 +425,7 @@ def _build_decoder_only_llm_inputs( def _process_decoder_only_prompt( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -436,7 +436,7 @@ def _process_decoder_only_prompt( Arguments: - * inputs: input prompt + * prompt: input prompt * request_id * lora_request * prompt_adapter_request @@ -447,7 +447,7 @@ def _process_decoder_only_prompt( ''' prompt_comps = self._extract_prompt_components( - inputs, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -459,14 +459,14 @@ def _process_decoder_only_prompt( async def _process_decoder_only_prompt_async( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" prompt_comps = await self._extract_prompt_components_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -478,7 +478,7 @@ async def _process_decoder_only_prompt_async( def preprocess( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -488,17 +488,17 @@ def preprocess( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return self._process_encoder_decoder_prompt( - inputs, + prompt, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return self._process_decoder_only_prompt( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -506,7 +506,7 @@ def preprocess( async def preprocess_async( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -516,17 +516,17 @@ async def preprocess_async( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return await self._process_encoder_decoder_prompt_async( - inputs, + prompt, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return await self._process_decoder_only_prompt_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, From 065a304ccb15c9f7f93ada3c205789ec0c100bcc 
Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 20 Sep 2024 17:28:19 +0000 Subject: [PATCH 49/88] Remove unnecessary comments --- vllm/entrypoints/llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 5a98ef11ab94d..c7548ca4bcfbd 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -259,7 +259,7 @@ def generate( def generate( self, prompts: Union[PromptType, Sequence[PromptType]], - /, # We may enable `inputs` keyword after removing the old API + /, *, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -498,7 +498,7 @@ def encode( def encode( self, prompts: Union[PromptType, Sequence[PromptType]], - /, # We may enable `inputs` keyword after removing the old API + /, *, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, From 741d4c145d3580598f2f792c459d5e67730c176a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 22 Sep 2024 14:47:51 +0000 Subject: [PATCH 50/88] Fix import --- vllm/model_executor/models/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 67cf7cf061788..42b10e0813b52 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,7 +1,7 @@ import itertools from collections import UserDict -from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, - Union, overload) +from typing import (Callable, Dict, Iterable, List, Literal, Optional, Protocol, + Tuple, Union, overload) import torch import torch.nn as nn From 935002288d629900911a14584348d756730600be Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 22 Sep 2024 14:49:37 +0000 Subject: [PATCH 51/88] Format --- vllm/model_executor/models/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 42b10e0813b52..e58368878d958 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,7 +1,7 @@ import itertools from collections import UserDict -from typing import (Callable, Dict, Iterable, List, Literal, Optional, Protocol, - Tuple, Union, overload) +from typing import (Callable, Dict, Iterable, List, Literal, Optional, + Protocol, Tuple, Union, overload) import torch import torch.nn as nn From 1de2b990f2429910b93167c7c4f027621dd28c92 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 28 Sep 2024 06:46:39 +0000 Subject: [PATCH 52/88] Add validation for embedding inputs --- vllm/inputs/preprocess.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 549c5ed1b1845..236e277c8d8e5 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,6 +1,7 @@ import asyncio from typing import List, Optional, Union +import torch from typing_extensions import assert_never from vllm.config import ModelConfig @@ -171,6 +172,13 @@ def _apply_prompt_adapter( return prompt_token_ids + def _validate_embed_inputs(self, prompt_embeds: torch.Tensor): + if len(prompt_embeds.shape) != 2: + raise ValueError("Embeddings should be a 2D input with shape " + "`(num_tokens, embed_dim)`") + + return prompt_embeds + def _tokenize_prompt( self, prompt: str, @@ -339,17 +347,18 @@ def _build_enc_dec_llm_inputs( elif encoder_inputs["type"] == "embed": raise 
NotImplementedError("Embedding inputs are not supported for " "encoder-decoder models yet") + encoder_inputs["prompt_embeds"] = self._validate_embed_inputs( + encoder_inputs["prompt_embeds"]) else: assert_never(encoder_inputs) if decoder_inputs["type"] == "token": - if "prompt_token_ids" in decoder_inputs: - decoder_inputs["prompt_token_ids"] = ( - self._prepare_decoder_input_ids_for_generation( - decoder_inputs["prompt_token_ids"], - force_bos=("multi_modal_data" not in encoder_inputs and - "multi_modal_data" not in decoder_inputs), - )) + decoder_inputs["prompt_token_ids"] = ( + self._prepare_decoder_input_ids_for_generation( + decoder_inputs["prompt_token_ids"], + force_bos=("multi_modal_data" not in encoder_inputs + and "multi_modal_data" not in decoder_inputs), + )) if "multi_modal_data" in decoder_inputs: raise ValueError("Multi-modal decoder inputs of encoder-" @@ -357,6 +366,8 @@ def _build_enc_dec_llm_inputs( elif decoder_inputs["type"] == "embed": raise NotImplementedError("Embedding inputs are not supported for " "encoder-decoder models yet") + decoder_inputs["prompt_embeds"] = self._validate_embed_inputs( + decoder_inputs["prompt_embeds"]) elif decoder_inputs["type"] == "empty": pass else: @@ -473,13 +484,13 @@ def _build_decoder_only_llm_inputs( prompt_adapter_request: Optional[PromptAdapterRequest], ) -> DecoderOnlyInputs: if prompt_inputs["type"] == "token": - if "prompt_token_ids" in prompt_inputs: - prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( - prompt_inputs["prompt_token_ids"], - prompt_adapter_request=prompt_adapter_request, - ) + prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( + prompt_inputs["prompt_token_ids"], + prompt_adapter_request=prompt_adapter_request, + ) elif prompt_inputs["type"] == "embed": - pass + prompt_inputs["prompt_embeds"] = self._validate_embed_inputs( + prompt_inputs["prompt_embeds"]) else: assert_never(prompt_inputs) From 9124115286808d4343ccd252c16d6fbdf84e9eb4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 28 Sep 2024 07:05:36 +0000 Subject: [PATCH 53/88] Update for mllama --- vllm/engine/llm_engine.py | 15 ++++----- vllm/inputs/__init__.py | 9 +++--- vllm/inputs/data.py | 9 +++--- vllm/inputs/parse.py | 11 +++---- vllm/inputs/preprocess.py | 8 ++--- vllm/inputs/registry.py | 16 ++++++---- vllm/model_executor/models/mllama.py | 48 ++++++++++++++++------------ vllm/sequence.py | 10 +++--- 8 files changed, 66 insertions(+), 60 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fd643c21cb8ec..5562834580390 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -28,9 +28,9 @@ from vllm.executor.executor_base import ExecutorBase from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, - EncoderDecoderInputs, InputRegistry, PromptType) -from vllm.inputs.parse import is_valid_encoder_decoder_inputs +from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, + PromptType) +from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -637,7 +637,7 @@ def _verify_args(self) -> None: def _add_processed_request( self, request_id: str, - processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], + processed_inputs: ProcessorInputs, params: Union[SamplingParams, PoolingParams], arrival_time: float, 
lora_request: Optional[LoRARequest], @@ -656,7 +656,7 @@ def _add_processed_request( lora_request, prompt_adapter_request) encoder_seq = None - if is_valid_encoder_decoder_inputs(processed_inputs): + if is_encoder_decoder_inputs(processed_inputs): encoder_seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, lora_request, prompt_adapter_request) @@ -1876,9 +1876,8 @@ def _support_prompt_embeds(self) -> Tuple[bool, str]: return False, (f"Model {self.model_config.model} does not support " "input embeddings, but prompt_embeds was provided.") - def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, - EncoderDecoderInputs]): - if is_valid_encoder_decoder_inputs(inputs): + def _validate_model_inputs(self, inputs: ProcessorInputs): + if is_encoder_decoder_inputs(inputs): prompt_ids = inputs["encoder"].get("prompt_token_ids") prompt_embeds = inputs["encoder"].get("prompt_embeds") else: diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 6dc903bb66a76..39bf4bf557d8a 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,9 +1,9 @@ from .data import (DecoderOnlyInputs, EmbedInputs, EmbedsPrompt, EmptyInputs, EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, - PromptType, SingletonPrompt, TextPrompt, TokenInputs, - TokensPrompt, build_explicit_enc_dec_prompt, embed_inputs, - empty_inputs, to_enc_dec_tuple_list, token_inputs, - zip_enc_dec_prompts) + ProcessorInputs, PromptType, SingletonPrompt, TextPrompt, + TokenInputs, TokensPrompt, build_explicit_enc_dec_prompt, + embed_inputs, empty_inputs, to_enc_dec_tuple_list, + token_inputs, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -30,6 +30,7 @@ "EmptyInputs", "empty_inputs", "EncoderDecoderInputs", + "ProcessorInputs", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 77cd6a3788024..28a74c7be174e 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -216,12 +216,11 @@ class EncoderDecoderInputs(TypedDict): decoder: Union[EmptyInputs, TokenInputs] """The inputs for the decoder portion.""" - encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] - """ - Optional multi-modal data to pass to the encoder model, - if the model supports it. - """ +ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] +""" +The inputs to :data:`vllm.inputs.InputProcessor`. 
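# Sketch of the two ProcessorInputs shapes built by the preprocessor; the token
# ids and tensor are placeholders, and embed_inputs() is assumed here to take
# the embeddings tensor directly.
import torch

from vllm.inputs import embed_inputs, empty_inputs, token_inputs
from vllm.inputs.parse import is_encoder_decoder_inputs

decoder_only = token_inputs([1, 2, 3], prompt="1 2 3")
embeds_only = embed_inputs(torch.zeros(3, 4096))  # must be 2D: (num_tokens, embed_dim)
enc_dec = {
    "encoder": token_inputs([1, 2, 3]),
    "decoder": empty_inputs(),  # later replaced with the decoder start token ids
}
assert is_encoder_decoder_inputs(enc_dec)  # has both "encoder" and "decoder"
assert not is_encoder_decoder_inputs(decoder_only)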
+""" _T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) _T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 3d36c60a42478..14a971ed6a644 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -4,9 +4,9 @@ from vllm.utils import is_list_of -from .data import (DecoderOnlyInputs, EmbedsPrompt, EncoderDecoderInputs, - ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt, - TextPrompt, TokensPrompt) +from .data import (EmbedsPrompt, EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, + SingletonPrompt, TextPrompt, TokensPrompt) class ParsedText(TypedDict): @@ -109,7 +109,6 @@ def is_explicit_encoder_decoder_prompt( return isinstance(prompt, dict) and "encoder_prompt" in prompt -def is_valid_encoder_decoder_inputs( - inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], -) -> TypeIs[EncoderDecoderInputs]: +def is_encoder_decoder_inputs( + inputs: ProcessorInputs, ) -> TypeIs[EncoderDecoderInputs]: return "encoder" in inputs and "decoder" in inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 236e277c8d8e5..9654f15135bb3 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -12,8 +12,8 @@ from vllm.utils import print_warning_once from .data import (DecoderOnlyInputs, EmptyInputs, EncoderDecoderInputs, - PromptType, SingletonPrompt, embed_inputs, empty_inputs, - token_inputs) + ProcessorInputs, PromptType, SingletonPrompt, embed_inputs, + empty_inputs, token_inputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt logger = init_logger(__name__) @@ -555,7 +555,7 @@ def preprocess( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: + ) -> ProcessorInputs: """Preprocess the input prompt.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of @@ -583,7 +583,7 @@ async def preprocess_async( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: + ) -> ProcessorInputs: """Async version of :meth:`preprocess`.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index fe0fb2a363dbc..86e0097c1b561 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -11,7 +11,7 @@ from vllm.logger import init_logger from vllm.utils import get_allowed_kwarg_only_overrides, print_warning_once -from .data import DecoderOnlyInputs +from .data import ProcessorInputs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -99,7 +99,7 @@ def __getitem__(self, key: str) -> int: raise KeyError(msg) from exc -InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs] +InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs] """Preprocess the inputs to the model.""" @@ -252,9 +252,8 @@ def dummy_data_for_profiling( return seq_data, mm_data - def _default_input_processor( - self, ctx: InputContext, - inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: + def _default_input_processor(self, ctx: InputContext, + inputs: ProcessorInputs) -> ProcessorInputs: """The default input processor is a no-op.""" return inputs @@ -286,8 +285,11 @@ def _get_model_input_processor(self, model_cls: 
Type[nn.Module]): return self._input_processors_by_model_type \ .get(model_cls, self._default_input_processor) - def process_input(self, model_config: "ModelConfig", - inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: + def process_input( + self, + model_config: "ModelConfig", + inputs: ProcessorInputs, + ) -> ProcessorInputs: """ Apply an input processor to an instance of model inputs. diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 45d6ad3c0efa5..115b83f5988cf 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -33,7 +33,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, InputContext, ProcessorInputs +from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -48,6 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.utils import is_list_of from .clip import CLIPMLP from .interfaces import SupportsMultiModal @@ -72,31 +74,36 @@ class MllamaImagePixelInputs(TypedDict): # TODO: support LlamaImageEmbeddingInputs -def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_mllama(ctx: InputContext, inputs: ProcessorInputs): + assert is_encoder_decoder_inputs(inputs) + enc_inputs = inputs["encoder"] + dec_inputs = inputs["decoder"] + # move encoder_prompt to prompt - if llm_inputs.get("prompt") is None: - llm_inputs["prompt"] = llm_inputs["encoder_prompt"] - llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] + if dec_inputs.get("prompt") is None: + dec_inputs["prompt"] = enc_inputs["prompt"] + dec_inputs["prompt_token_ids"] = enc_inputs["prompt_token_ids"] # process multi-modal data - assert "decoder_multi_modal_data" not in llm_inputs, \ - "multi-modal data should be put in encoder message of mllama" - multi_modal_data = llm_inputs.get("encoder_multi_modal_data") + multi_modal_data = enc_inputs.get("multi_modal_data") + image_data = (None if multi_modal_data is None else + multi_modal_data.get("image")) - if multi_modal_data is None or "image" not in multi_modal_data \ - or multi_modal_data["image"] is None: + if image_data is None: # text-only - llm_inputs["encoder_prompt"] = "" - llm_inputs["encoder_prompt_token_ids"] = [] - llm_inputs["encoder_multi_modal_data"] = {} - return llm_inputs + enc_inputs["prompt"] = "" + enc_inputs["prompt_token_ids"] = [] + enc_inputs["multi_modal_data"] = {} + return inputs # get num_tiles - if isinstance(multi_modal_data['image'], Image.Image): - multi_modal_data['image'] = [multi_modal_data['image']] + if isinstance(image_data, Image.Image): + image_data = [image_data] + assert is_list_of(image_data, Image.Image) + hf_config = ctx.model_config.hf_config num_tiles = 0 - for image in multi_modal_data["image"]: + for image in image_data: width, height = image.size tile_size = hf_config.vision_config.image_size canvas_height, canvas_width = get_optimal_tiled_canvas( @@ -114,11 +121,10 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): 
"chunk size should be multiple of 14" token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 num_tokens = num_tiles * token_per_chunk - llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens - llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID - ] * num_tokens + enc_inputs["prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens + enc_inputs["prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens - return llm_inputs + return inputs def get_max_mllama_image_tokens(ctx: InputContext) -> int: diff --git a/vllm/sequence.py b/vllm/sequence.py index 1f0d053aa03d2..027e6dbd1ec50 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -13,7 +13,7 @@ import msgspec import torch -from vllm.inputs.parse import is_valid_encoder_decoder_inputs +from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -439,7 +439,7 @@ def n_blocks(self) -> int: def prompt(self) -> Optional[str]: # Select decoder or encoder input prompt str, as appropriate inputs = self.inputs - if is_valid_encoder_decoder_inputs(inputs): + if is_encoder_decoder_inputs(inputs): prompt = inputs["encoder"].get("prompt") else: prompt = cast(Optional[str], inputs.get("prompt")) @@ -450,7 +450,7 @@ def prompt(self) -> Optional[str]: def prompt_token_ids(self) -> List[int]: # Select decoder or encoder input prompt token ids, as appropriate inputs = self.inputs - if is_valid_encoder_decoder_inputs(inputs): + if is_encoder_decoder_inputs(inputs): prompt_token_ids = inputs["encoder"].get("prompt_token_ids") else: prompt_token_ids = cast(Optional[List[int]], @@ -462,7 +462,7 @@ def prompt_token_ids(self) -> List[int]: def prompt_embeds(self) -> Optional[torch.Tensor]: # Select decoder or encoder input prompt embeds, as appropriate inputs = self.inputs - if is_valid_encoder_decoder_inputs(inputs): + if is_encoder_decoder_inputs(inputs): prompt_embeds = inputs["encoder"].get("prompt_embeds") else: prompt_embeds = cast(Optional[torch.Tensor], @@ -473,7 +473,7 @@ def prompt_embeds(self) -> Optional[torch.Tensor]: @cached_property def multi_modal_data(self) -> "MultiModalDataDict": inputs = self.inputs - if is_valid_encoder_decoder_inputs(inputs): + if is_encoder_decoder_inputs(inputs): multi_modal_data = inputs["encoder"].get("multi_modal_data") else: multi_modal_data = cast(Optional["MultiModalDataDict"], From 6c366ebbdf0206f02f3d867cc814a4b15ece7d10 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 28 Sep 2024 07:10:00 +0000 Subject: [PATCH 54/88] format --- vllm/inputs/parse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 14a971ed6a644..18a57885f0ce0 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -110,5 +110,5 @@ def is_explicit_encoder_decoder_prompt( def is_encoder_decoder_inputs( - inputs: ProcessorInputs, ) -> TypeIs[EncoderDecoderInputs]: + inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]: return "encoder" in inputs and "decoder" in inputs From ad6c364628c0345ae402e498e81efe2b564aa595 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Oct 2024 15:32:05 +0000 Subject: [PATCH 55/88] Fix failing tests --- tests/core/utils.py | 54 ++++++++----------- .../vision_language/test_phi3v.py | 8 +-- tests/multimodal/test_processor_kwargs.py | 10 ++-- vllm/model_executor/models/llava_onevision.py | 39 +++++++------- vllm/model_executor/models/phi3v.py | 10 ++-- 5 files changed, 
56 insertions(+), 65 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index d7459b90ab5f9..7fd454af12432 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -4,7 +4,7 @@ from typing import Tuple from vllm import SamplingParams -from vllm.inputs import EncoderDecoderInputs +from vllm.inputs import EncoderDecoderInputs, token_inputs from vllm.lora.request import LoRARequest from vllm.sequence import Logprob, Sequence, SequenceGroup @@ -27,10 +27,7 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence(int(request_id), - inputs={ - "prompt": prompt_str, - "prompt_token_ids": prompt_tokens, - }, + inputs=token_inputs(prompt_tokens, prompt=prompt_str), block_size=block_size) seq_group = SequenceGroup(request_id=request_id, seqs=[prompt], @@ -64,24 +61,18 @@ def create_dummy_prompt_encoder_decoder( encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) inputs: EncoderDecoderInputs = { - "decoder": { - "type": "token", - "prompt": decoder_prompt_str, - "prompt_token_ids": decoder_prompt_tokens, - }, - "encoder": { - "type": "token", - "prompt": encoder_prompt_str, - "prompt_token_ids": encoder_prompt_tokens, - }, + "decoder": token_inputs(decoder_prompt_tokens, + prompt=decoder_prompt_str), + "encoder": token_inputs(encoder_prompt_tokens, + prompt=encoder_prompt_str), } decoder_prompt = Sequence(int(request_id), - inputs=inputs, + inputs=inputs["decoder"], block_size=block_size) encoder_prompt = Sequence(int(request_id), - inputs=inputs, + inputs=inputs["encoder"], block_size=block_size) seq_group = SequenceGroup(request_id=request_id, @@ -114,7 +105,7 @@ def create_seq_group( for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - inputs={"prompt_token_ids": prompt_token_ids}, + inputs=token_inputs(prompt_token_ids), block_size=16, ) @@ -149,21 +140,19 @@ def create_seq_group_encoder_decoder( prompt_token_ids = [0] * seq_prompt_len - inputs = { - "prompt": "", - "prompt_token_ids": prompt_token_ids, - "encoder_prompt": "", - "encoder_prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, + inputs: EncoderDecoderInputs = { + "decoder": token_inputs(prompt_token_ids), + "encoder": token_inputs(prompt_token_ids), } seqs = [] for seq_id_offset, output_len in enumerate(seq_output_lens): # Construct decoder input sequences - seq = Sequence(seq_id=seq_id_start + seq_id_offset, - inputs=inputs, - block_size=16, - from_decoder_prompt=True) + seq = Sequence( + seq_id=seq_id_start + seq_id_offset, + inputs=inputs["decoder"], + block_size=16, + ) for i in range(output_len): seq.append_token_id( @@ -173,10 +162,11 @@ def create_seq_group_encoder_decoder( seqs.append(seq) # Encoder input sequence - encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens), - inputs=inputs, - block_size=16, - from_decoder_prompt=False) + encoder_seq = Sequence( + seq_id=seq_id_start + len(seq_output_lens), + inputs=inputs["encoder"], + block_size=16, + ) return SequenceGroup(request_id=request_id, seqs=seqs, diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index 00c1b9975ef35..c972ea4d0c677 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -6,7 +6,7 @@ import torch from transformers import AutoImageProcessor, AutoTokenizer -from vllm.inputs import 
InputContext, LLMInputs +from vllm.inputs import InputContext, token_inputs from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID from vllm.multimodal import MultiModalRegistry from vllm.multimodal.utils import rescale_image_size @@ -393,9 +393,9 @@ def test_input_processor_override(input_processor_for_phi3v: Callable, prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" images = [image_assets[0].pil_image] * num_imgs - llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt), - prompt=prompt, - multi_modal_data={"image": images}) + llm_inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt), + prompt=prompt, + multi_modal_data={"image": images}) proc_llm_inputs = input_processor_for_phi3v( ctx=ctx, diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 5529ccd4fa570..0668948e1ad60 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -5,7 +5,7 @@ import pytest import torch -from vllm.inputs import InputContext, LLMInputs +from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs from vllm.inputs.registry import InputRegistry from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -31,7 +31,7 @@ def use_processor_mock(): """Patches the internal model input processor with an override callable.""" def custom_processor(ctx: InputContext, - llm_inputs: LLMInputs, + inputs: DecoderOnlyInputs, *, num_crops=DEFAULT_NUM_CROPS): # For testing purposes, we don't worry about the llm inputs / return @@ -84,7 +84,7 @@ def test_default_processor_is_a_noop(): dummy_registry = InputRegistry() ctx = build_model_context(DUMMY_MODEL_ID) processor = dummy_registry.create_input_processor(ctx.model_config) - proc_inputs = LLMInputs(prompt_token_ids=[], prompt="") + proc_inputs = token_inputs(prompt_token_ids=[], prompt="") proc_outputs = processor(inputs=proc_inputs) assert proc_inputs is proc_outputs @@ -104,7 +104,7 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): mm_processor_kwargs=mm_processor_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) - num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) + num_crops_val = processor(token_inputs(prompt_token_ids=[], prompt="")) assert num_crops_val == expected_num_crops @@ -128,7 +128,7 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, mm_processor_kwargs=mm_processor_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) - num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) + num_crops_val = processor(token_inputs(prompt_token_ids=[], prompt="")) assert num_crops_val == DEFAULT_NUM_CROPS diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 9099d4f88222d..41db69a2c3173 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -14,7 +14,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization.base_config import ( @@ -253,10 +254,10 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, 
def input_processor_when_multimodal_input_image(ctx: InputContext, - llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") + inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config hf_config = ctx.get_hf_config(LlavaOnevisionConfig) @@ -291,7 +292,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, return input_processor_for_clip( model_config, vision_config, - llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) @@ -299,7 +300,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, return input_processor_for_siglip( model_config, vision_config, - llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) @@ -309,10 +310,10 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, def input_processor_when_multimodal_input_video(ctx: InputContext, - llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") + inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "video" not in multi_modal_data: - return llm_inputs + return inputs video_data = multi_modal_data["video"] model_config = ctx.model_config @@ -327,15 +328,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=hf_config.video_token_index, repeat_count=video_feature_size, ) - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) elif is_list_of(video_data, np.ndarray): raise NotImplementedError( @@ -346,15 +347,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, def input_processor_for_llava_onevision(ctx: InputContext, - llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") + inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or ("video" not in multi_modal_data and "image" not in multi_modal_data): - return llm_inputs + return inputs if "image" in multi_modal_data: - return input_processor_when_multimodal_input_image(ctx, llm_inputs) + return input_processor_when_multimodal_input_image(ctx, inputs) if "video" in multi_modal_data: - return input_processor_when_multimodal_input_video(ctx, llm_inputs) + return input_processor_when_multimodal_input_video(ctx, inputs) msg = "Unsupported multi data type" raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 0026c76d1e1e6..e8efa84e7ba96 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -27,7 +27,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import 
LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -498,10 +499,9 @@ def input_processor_for_phi3v(ctx: InputContext, new_token_ids.append(token_id) # NOTE: Create a defensive copy of the original inputs - llm_inputs = DecoderOnlyInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) - return llm_inputs + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) @MULTIMODAL_REGISTRY.register_image_input_mapper() From 13bbd02a8d2cccd7f87093d14045596b2369cefc Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 5 Oct 2024 02:50:41 +0000 Subject: [PATCH 56/88] format --- vllm/model_executor/models/arctic.py | 3 +-- vllm/model_executor/models/baichuan.py | 3 +-- vllm/model_executor/models/bloom.py | 3 +-- vllm/model_executor/models/chatglm.py | 3 +-- vllm/model_executor/models/commandr.py | 3 +-- vllm/model_executor/models/dbrx.py | 3 +-- vllm/model_executor/models/deepseek.py | 3 +-- vllm/model_executor/models/falcon.py | 3 +-- vllm/model_executor/models/gemma2.py | 3 +-- vllm/model_executor/models/gpt_bigcode.py | 3 +-- vllm/model_executor/models/gpt_j.py | 3 +-- vllm/model_executor/models/gpt_neox.py | 3 +-- vllm/model_executor/models/mixtral.py | 3 +-- vllm/model_executor/models/mixtral_quant.py | 3 +-- vllm/model_executor/models/mpt.py | 3 +-- vllm/model_executor/models/olmo.py | 3 +-- vllm/model_executor/models/orion.py | 3 +-- vllm/model_executor/models/persimmon.py | 3 +-- vllm/model_executor/models/phi.py | 3 +-- vllm/model_executor/models/phi3_small.py | 3 +-- vllm/model_executor/models/phimoe.py | 3 +-- vllm/model_executor/models/stablelm.py | 3 +-- vllm/model_executor/models/starcoder2.py | 3 +-- vllm/model_executor/models/xverse.py | 3 +-- vllm/worker/model_runner.py | 4 ++-- 25 files changed, 26 insertions(+), 50 deletions(-) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 1b0741f36ccfd..e4e78c03cc0dc 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -398,8 +398,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) else: diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 5b650ddd0b14b..7bb3a4f798a30 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -289,8 +289,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 2970cf706305b..3fb8b0f98eb11 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -261,8 +261,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.word_embeddings, + hidden_states = get_inputs_embeds(input_ids, self.word_embeddings, 
inputs_embeds, inputs_embeds_masks) hidden_states = self.word_embeddings_layernorm(hidden_states) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index a031a40bb8dba..3734db446af22 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -323,8 +323,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embedding, + hidden_states = get_inputs_embeds(input_ids, self.embedding, inputs_embeds, inputs_embeds_masks) else: diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 1758e9c8f8da5..c29a05ed069fa 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -289,8 +289,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 30fffe61e5ed9..8625a76e2fca9 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -332,8 +332,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.wte, + hidden_states = get_inputs_embeds(input_ids, self.wte, inputs_embeds, inputs_embeds_masks) else: diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 2bd89e207d6cd..ec6d2eb5403c6 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -365,8 +365,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index a5e5a455838a9..ae736ad6df61d 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -374,8 +374,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.word_embeddings, + hidden_states = get_inputs_embeds(input_ids, self.word_embeddings, inputs_embeds, inputs_embeds_masks) else: diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index a3ecd0bb67783..4542bc4faaf91 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -282,8 +282,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) hidden_states *= self.normalizer diff --git a/vllm/model_executor/models/gpt_bigcode.py 
b/vllm/model_executor/models/gpt_bigcode.py index f922a143abe0c..4945faafb3573 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -230,8 +230,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.wte, + hidden_states = get_inputs_embeds(input_ids, self.wte, inputs_embeds, inputs_embeds_masks) position_embeds = self.wpe(position_ids) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index e11b829a566b0..d8b323c5d24c7 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -211,8 +211,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.wte, + hidden_states = get_inputs_embeds(input_ids, self.wte, inputs_embeds, inputs_embeds_masks) else: diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index cfdc88d303cf9..00096b48ad16b 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -225,8 +225,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_in, + hidden_states = get_inputs_embeds(input_ids, self.embed_in, inputs_embeds, inputs_embeds_masks) else: diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a42d8df3a222c..b6d9704b9e943 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -291,8 +291,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 7730c67adc7e4..afacff7abf6fd 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -330,8 +330,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 22015e5af6736..0397b1bcba7bd 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -247,8 +247,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.wte, + hidden_states = get_inputs_embeds(input_ids, self.wte, inputs_embeds, inputs_embeds_masks) else: diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 5a1efb8bdbc52..9b855b72fda69 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -261,8 
+261,7 @@ def forward( if get_pp_group().is_first_rank: # Get embeddings of input. # shape: (batch_size, seq_len, d_model) - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 7d06911182620..0336d920ec4eb 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -247,8 +247,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index f9678b31eed91..d84283b2c8c20 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -243,8 +243,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) else: diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index eafcc5d9c6573..8bff3034a3572 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -225,8 +225,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) else: diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 4655ad3c8d998..1870f0e31593e 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -341,8 +341,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) if (self.mup_embedding_multiplier is not None diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 770c463b28771..1189398edad8f 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -475,8 +475,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 16b51c8a2d01c..dc7bf52aff4e0 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -228,8 +228,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - 
hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) else: diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 5268aa253854e..32aa0a5e9acdc 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -230,8 +230,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) else: diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 52f42c38418f5..eea9689a52368 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -264,8 +264,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, - self.embed_tokens, + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) residual = None diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 37e134a838dfe..4b61e3bfa51ca 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1697,8 +1697,8 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), + **MultiModalInputs.as_kwargs( + multi_modal_kwargs, device=self.device), **seqlen_agnostic_kwargs) if self.model_supports_input_embeds: model_params.update( From 4000b90c5df8a1ddfa0c279f0f7f478c2b486ed7 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 7 Oct 2024 16:46:16 +0000 Subject: [PATCH 57/88] Improve type annotations --- vllm/model_executor/models/__init__.py | 8 ++-- vllm/model_executor/models/interfaces.py | 47 ++++++++++++++++++++---- 2 files changed, 44 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 59a6c27258d95..72fc18a780a50 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,6 +1,7 @@ -from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, - SupportsPP, has_inner_state, supports_input_embeds, - supports_lora, supports_multimodal, supports_pp) +from .interfaces import (HasInnerState, SupportsInputEmbeds, SupportsLoRA, + SupportsMultiModal, SupportsPP, has_inner_state, + supports_input_embeds, supports_lora, + supports_multimodal, supports_pp) from .interfaces_base import (VllmModelForEmbedding, VllmModelForTextGeneration, is_embedding_model, is_text_generation_model) @@ -14,6 +15,7 @@ "is_text_generation_model", "HasInnerState", "has_inner_state", + "SupportsInputEmbeds", "supports_input_embeds", "SupportsLoRA", "supports_lora", diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 1d6d869b29d6f..e32a1a305521e 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -2,7 +2,6 @@ Protocol, Type, Union, overload, runtime_checkable) import torch -import torch.nn as nn from typing_extensions import TypeIs from vllm.logger import init_logger @@ 
-15,6 +14,45 @@ logger = init_logger(__name__) +@runtime_checkable +class SupportsInputEmbeds(Protocol): + """The interface required to support embedding inputs.""" + + def forward( + self, + *, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + ... + + +@overload +def supports_input_embeds( + model: Type[object]) -> TypeIs[Type[SupportsInputEmbeds]]: + ... + + +@overload +def supports_input_embeds(model: object) -> TypeIs[SupportsInputEmbeds]: + ... + + +def supports_input_embeds( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsInputEmbeds]], TypeIs[SupportsInputEmbeds]]: + """Check if the model supports input_embeds and input_embeds_masks.""" + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + required_kws = ("inputs_embeds", "inputs_embeds_masks") + missing_kws = tuple(kw for kw in required_kws + if not supports_kw(model_forward, kw)) + + return len(missing_kws) == 0 + + @runtime_checkable class SupportsMultiModal(Protocol): """The interface required for all multi-modal models.""" @@ -308,10 +346,3 @@ def has_inner_state( return isinstance(model, _HasInnerStateType) return isinstance(model, HasInnerState) - - -def supports_input_embeds(model: nn.Module) -> bool: - """Check if the model supports input_embeds and input_embeds_masks.""" - model_forward = model.forward - return (supports_kw(model_forward, "inputs_embeds") - and supports_kw(model_forward, "inputs_embeds_masks")) From 1bafe1b459182364aaf978a28dd361d948c5f5b4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 16 Oct 2024 12:48:51 +0000 Subject: [PATCH 58/88] Update --- vllm/inputs/registry.py | 98 +++++++++++++++++++++++++---------------- 1 file changed, 60 insertions(+), 38 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 5cb041d110004..35b64f9e0c19a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,14 +1,15 @@ import functools from collections import UserDict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, - Optional, Protocol, Tuple, Type, cast) +from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, + Protocol, Tuple, Type, overload) from torch import nn from transformers import PretrainedConfig -from typing_extensions import TypeVar +from typing_extensions import TypeVar, assert_never -from vllm.inputs import SingletonInputs +from vllm.inputs import (EmbedInputs, EmptyInputs, EncoderDecoderInputs, + SingletonInputs, TokenInputs) from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.logger import init_logger from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, @@ -280,11 +281,52 @@ def _get_model_input_processor(self, model_cls: Type[nn.Module]): return self._input_processors_by_model_type \ .get(model_cls, self._default_input_processor) - def process_input( - self, - model_config: "ModelConfig", - inputs: ProcessorInputs, - ) -> ProcessorInputs: + @overload + def _process_singleton_input(self, model_config: "ModelConfig", + inputs: TokenInputs) -> TokenInputs: + ... + + @overload + def _process_singleton_input(self, model_config: "ModelConfig", + inputs: EmbedInputs) -> EmbedInputs: + ... + + @overload + def _process_singleton_input(self, model_config: "ModelConfig", + inputs: EmptyInputs) -> EmptyInputs: + ... 
+ + def _process_singleton_input(self, model_config: "ModelConfig", + inputs: SingletonInputs) -> SingletonInputs: + if inputs["type"] == "empty": + return inputs + + if inputs["type"] == "embed": + return inputs + + if inputs["type"] == "token": + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + processor = self._get_model_input_processor(model_cls) + + # Handle multimodal processor kwargs with priority: + # Inference kwargs -> Init kwargs -> {} + # If it's empty, it'll fall back to the default kwarg values + mm_processor_kwargs = resolve_mm_processor_kwargs( + model_config.mm_processor_kwargs, + inputs.pop("mm_processor_kwargs"), + processor, + ) + + return processor(InputContext(model_config), inputs, + **mm_processor_kwargs) + + assert_never(inputs) + + def process_input(self, model_config: "ModelConfig", + inputs: ProcessorInputs) -> ProcessorInputs: """ Apply an input processor to an instance of model inputs. @@ -293,35 +335,15 @@ def process_input( See also: :ref:`input_processing_pipeline` """ - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) - processor = self._get_model_input_processor(model_cls) - - inputs_mm_processor_kwargs: Dict[str, Any] = {} - singleton_inputs = cast( - List[SingletonInputs], - ([inputs["encoder"], inputs["decoder"]] - if is_encoder_decoder_inputs(inputs) else [inputs]), - ) - for singleton_input in singleton_inputs: - if singleton_input["type"] == "token": - kw = singleton_input.pop("mm_processor_kwargs") - if kw is not None: - inputs_mm_processor_kwargs.update(kw) - - # Handle multimodal processor kwargs with priority: - # Inference kwargs -> Init kwargs -> {} - # If it's empty, it'll fall back to the default kwarg values - inputs_mm_processor_kwargs = resolve_mm_processor_kwargs( - model_config.mm_processor_kwargs, - inputs_mm_processor_kwargs, - processor, - ) - - return processor(InputContext(model_config), inputs, - **inputs_mm_processor_kwargs) + if is_encoder_decoder_inputs(inputs): + return EncoderDecoderInputs( + encoder=self._process_singleton_input(model_config, + inputs["encoder"]), + decoder=self._process_singleton_input(model_config, + inputs["decoder"]), + ) + + return self._process_singleton_input(model_config, inputs) def create_input_processor(self, model_config: "ModelConfig"): """ From 2da9eea63814e07261ff528ec1b515e416c7a9ae Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 22 Oct 2024 17:15:56 +0000 Subject: [PATCH 59/88] Fix KeyError; debug --- vllm/inputs/registry.py | 2 +- vllm/model_executor/models/utils.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 35b64f9e0c19a..733fb4d11400a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -316,7 +316,7 @@ def _process_singleton_input(self, model_config: "ModelConfig", # If it's empty, it'll fall back to the default kwarg values mm_processor_kwargs = resolve_mm_processor_kwargs( model_config.mm_processor_kwargs, - inputs.pop("mm_processor_kwargs"), + inputs.get("mm_processor_kwargs"), processor, ) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 4beba515753e5..0a59684e9cc86 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -477,6 +477,7 @@ def get_inputs_embeds( hidden_states = 
embeddings_module(input_ids) + assert hidden_states is not None return hidden_states From 0c872b3fd92f07376f8ff08ad99550eb8b7eea11 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 23 Oct 2024 04:59:22 +0000 Subject: [PATCH 60/88] Make encoder-decoder inputs a composed structure --- tests/core/utils.py | 57 +++---- tests/test_cache_block_hashing.py | 7 +- tests/tokenization/test_detokenize.py | 6 +- vllm/engine/llm_engine.py | 42 ++--- vllm/inputs/__init__.py | 12 +- vllm/inputs/data.py | 48 +++--- vllm/inputs/parse.py | 11 +- vllm/inputs/preprocess.py | 237 +++++++++++++------------- vllm/inputs/registry.py | 10 +- vllm/model_executor/models/mllama.py | 51 +++--- vllm/sequence.py | 102 ++++------- 11 files changed, 275 insertions(+), 308 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index a95a573db7cd3..cd0caa4704e11 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -4,6 +4,7 @@ from typing import Tuple from vllm import SamplingParams +from vllm.inputs import EncoderDecoderInputs, token_inputs from vllm.lora.request import LoRARequest from vllm.sequence import Logprob, Sequence, SequenceGroup @@ -27,10 +28,7 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence(int(request_id), - inputs={ - "prompt": prompt_str, - "prompt_token_ids": prompt_tokens, - }, + inputs=token_inputs(prompt_tokens, prompt=prompt_str), block_size=block_size) seq_group = SequenceGroup(request_id=request_id, seqs=[prompt], @@ -63,23 +61,21 @@ def create_dummy_prompt_encoder_decoder( encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) - inputs = { - "prompt": decoder_prompt_str, - "prompt_token_ids": decoder_prompt_tokens, - "encoder_prompt": encoder_prompt_str, - "encoder_prompt_token_ids": encoder_prompt_tokens, - "multi_modal_data": None, + inputs: EncoderDecoderInputs = { + "decoder": token_inputs(decoder_prompt_tokens, + prompt=decoder_prompt_str), + "encoder": token_inputs(encoder_prompt_tokens, + prompt=encoder_prompt_str), } decoder_prompt = Sequence(int(request_id), - inputs=inputs, - block_size=block_size, - from_decoder_prompt=True) + inputs=inputs["decoder"], + block_size=block_size) encoder_prompt = Sequence(int(request_id), - inputs=inputs, - block_size=block_size, - from_decoder_prompt=False) + inputs=inputs["encoder"], + block_size=block_size) + seq_group = SequenceGroup(request_id=request_id, seqs=[decoder_prompt], sampling_params=SamplingParams(best_of=best_of), @@ -108,7 +104,7 @@ def create_seq_group( for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - inputs={"prompt_token_ids": prompt_token_ids}, + inputs=token_inputs(prompt_token_ids), block_size=16, ) @@ -143,21 +139,19 @@ def create_seq_group_encoder_decoder( prompt_token_ids = [0] * seq_prompt_len - inputs = { - "prompt": "", - "prompt_token_ids": prompt_token_ids, - "encoder_prompt": "", - "encoder_prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, + inputs: EncoderDecoderInputs = { + "decoder": token_inputs(prompt_token_ids), + "encoder": token_inputs(prompt_token_ids), } seqs = [] for seq_id_offset, output_len in enumerate(seq_output_lens): # Construct decoder input sequences - seq = Sequence(seq_id=seq_id_start + seq_id_offset, - inputs=inputs, - block_size=16, - from_decoder_prompt=True) + seq = Sequence( + seq_id=seq_id_start + 
seq_id_offset, + inputs=inputs["decoder"], + block_size=16, + ) for i in range(output_len): seq.append_token_id( @@ -167,10 +161,11 @@ def create_seq_group_encoder_decoder( seqs.append(seq) # Encoder input sequence - encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens), - inputs=inputs, - block_size=16, - from_decoder_prompt=False) + encoder_seq = Sequence( + seq_id=seq_id_start + len(seq_output_lens), + inputs=inputs["encoder"], + block_size=16, + ) return SequenceGroup(request_id=request_id, seqs=seqs, diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index 3576a4834ebc3..e8f8499aa88ca 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -6,6 +6,7 @@ import pytest +from vllm.inputs import token_inputs from vllm.lora.request import LoRARequest from vllm.sequence import Sequence from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -70,10 +71,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, hashes[-1].append([]) prompt_token_ids = tokenizer.encode(prompt) seq = Sequence(seq_id, - inputs={ - "prompt": prompt, - "prompt_token_ids": prompt_token_ids, - }, + inputs=token_inputs(prompt_token_ids, + prompt=prompt), block_size=block_size, eos_token_id=tokenizer.tokenizer.eos_token_id, lora_request=lora_request) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index f4551ed42efb8..921ce6b097301 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -3,6 +3,7 @@ import pytest from transformers import AutoTokenizer +from vllm.inputs import token_inputs from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup from vllm.transformers_utils.detokenizer import (Detokenizer, detokenize_incrementally) @@ -123,10 +124,7 @@ def create_sequence(prompt_token_ids=None): prompt_token_ids = prompt_token_ids or [1] return Sequence( seq_id=0, - inputs={ - "prompt": "", - "prompt_token_ids": prompt_token_ids, - }, + inputs=token_inputs(prompt_token_ids, prompt=""), block_size=16, ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 25c4e76d9b159..2a302b058f6b5 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -30,8 +30,9 @@ from vllm.executor.executor_base import ExecutorBase from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, - EncoderDecoderInputs, InputRegistry, PromptType) +from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, + PromptType) +from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -639,7 +640,7 @@ def _verify_args(self) -> None: def _add_processed_request( self, request_id: str, - processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], + processed_inputs: ProcessorInputs, params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], @@ -656,18 +657,19 @@ def _add_processed_request( seq_id = next(self.seq_counter) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) - seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, + if is_encoder_decoder_inputs(processed_inputs): + decoder_inputs = processed_inputs["decoder"] + encoder_inputs = processed_inputs["encoder"] + else: + 
decoder_inputs = processed_inputs + encoder_inputs = None + + seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, lora_request, prompt_adapter_request) - encoder_seq = None - if 'encoder_prompt_token_ids' in processed_inputs: - encoder_seq = Sequence(seq_id, - processed_inputs, - block_size, - eos_token_id, - lora_request, - prompt_adapter_request, - from_decoder_prompt=False) + encoder_seq = (None if encoder_inputs is None else Sequence( + seq_id, encoder_inputs, block_size, eos_token_id, lora_request, + prompt_adapter_request)) # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): @@ -1909,16 +1911,16 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: def is_encoder_decoder_model(self): return self.input_preprocessor.is_encoder_decoder_model() - def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, - EncoderDecoderInputs]): - if self.model_config.is_multimodal_model: + def _validate_model_inputs(self, inputs: ProcessorInputs): + if is_encoder_decoder_inputs(inputs): # For encoder-decoder multimodal models, the max_prompt_len # restricts the decoder prompt length - prompt_ids = inputs.get("prompt_token_ids") - elif self.is_encoder_decoder_model(): - prompt_ids = inputs.get("encoder_prompt_token_ids") + prompt_inputs = inputs["decoder" if self.model_config. + is_multimodal_model else "encoder"] else: - prompt_ids = inputs.get("prompt_token_ids") + prompt_inputs = inputs + + prompt_ids = prompt_inputs.get("prompt_token_ids") if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 7b73922ddd2c5..57793349780c2 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,8 +1,8 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs, - ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs, - SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, - build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, - token_inputs, zip_enc_dec_prompts) + ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, + SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, + TokensPrompt, build_explicit_enc_dec_prompt, + to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -22,9 +22,9 @@ "ExplicitEncoderDecoderPrompt", "TokenInputs", "token_inputs", - "SingletonInputs", - "DecoderOnlyInputs", "EncoderDecoderInputs", + "ProcessorInputs", + "SingletonInputs", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 9a094191eda38..8f91d3867047d 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,4 +1,4 @@ -from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, +from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal, Optional, Tuple, Union, cast) from typing_extensions import NotRequired, TypedDict, TypeVar @@ -122,21 +122,25 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): class TokenInputs(TypedDict): """Represents token-based inputs.""" + + type: Literal["token"] + """The type of inputs.""" + prompt_token_ids: List[int] """The token IDs of the prompt.""" - prompt: NotRequired[Optional[str]] + prompt: NotRequired[str] """ The original prompt text corresponding to the token IDs, if available. 
""" - multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] + multi_modal_data: NotRequired["MultiModalDataDict"] """ Optional multi-modal data to pass to the model, if the model supports it. """ - mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]] + mm_processor_kwargs: NotRequired[Dict[str, Any]] """ Optional multi-modal processor kwargs to be forwarded to the multimodal input mapper & processor. Note that if multiple modalities @@ -152,7 +156,7 @@ def token_inputs( mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> TokenInputs: """Construct :class:`TokenInputs` from optional values.""" - inputs = TokenInputs(prompt_token_ids=prompt_token_ids) + inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) if prompt is not None: inputs["prompt"] = prompt @@ -164,12 +168,6 @@ def token_inputs( return inputs -SingletonInputs = TokenInputs -""" -A processed :class:`SingletonPrompt` which can be passed to -:class:`vllm.sequence.Sequence`. -""" - DecoderOnlyInputs = TokenInputs """ The inputs in :class:`~vllm.LLMEngine` before they are @@ -178,28 +176,30 @@ def token_inputs( """ -class EncoderDecoderInputs(TokenInputs): +class EncoderDecoderInputs(TypedDict): """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. This specifies the required data for encoder-decoder models. """ - encoder_prompt_token_ids: List[int] - """The token IDs of the encoder prompt.""" + encoder: TokenInputs + """The inputs for the encoder portion.""" - encoder_prompt: NotRequired[Optional[str]] - """ - The original encoder prompt text corresponding to the token IDs, if - available. - """ + decoder: TokenInputs + """The inputs for the decoder portion.""" - encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] - """ - Optional multi-modal data to pass to the encoder model, - if the model supports it. - """ +SingletonInputs = TokenInputs +""" +A processed :class:`SingletonPrompt` which can be passed to +:class:`vllm.sequence.Sequence`. +""" + +ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] +""" +The inputs to :data:`vllm.inputs.InputProcessor`. 
+""" _T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) _T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index e79d2c813bb4f..b11a151c4a585 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -4,9 +4,9 @@ from vllm.utils import is_list_of -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, - ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt, - TextPrompt, TokensPrompt) +from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, + ProcessorInputs, PromptType, SingletonPrompt, TextPrompt, + TokensPrompt) class ParsedText(TypedDict): @@ -104,6 +104,5 @@ def is_explicit_encoder_decoder_prompt( def is_encoder_decoder_inputs( - inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], -) -> TypeIs[EncoderDecoderInputs]: - return "encoder_prompt_token_ids" in inputs + inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]: + return "encoder" in inputs and "decoder" in inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 82ce7d392b719..9681c7dc548f6 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional from typing_extensions import assert_never @@ -10,22 +10,12 @@ from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.utils import print_warning_once -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType, - SingletonPrompt) +from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, + PromptType, SingletonPrompt, TokenInputs, token_inputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt -if TYPE_CHECKING: - from vllm.multimodal import MultiModalDataDict - logger = init_logger(__name__) -PromptComponents = Tuple[Optional[str], List[int], - Optional["MultiModalDataDict"], Optional[Dict[str, - Any]]] -DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], - Optional["MultiModalDataDict"], - Optional[Dict[str, Any]]] - class InputPreprocessor: @@ -115,7 +105,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: "default" decoder prompt be . However, it is possible that in the future - other models may have different or more + other models may have different or more complex logic for the default decoder prompt. This motivates having a special helper method for default decoder prompts. @@ -209,12 +199,12 @@ async def _tokenize_prompt_async( prompt=prompt, lora_request=lora_request) - def _extract_prompt_components( + def _prompt_to_llm_inputs( self, prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: + ) -> DecoderOnlyInputs: ''' Extract the components of any single encoder or decoder input prompt. 
@@ -241,14 +231,24 @@ def _extract_prompt_components( request_id=request_id, lora_request=lora_request, ) - multi_modal_data = None - mm_processor_kwargs = None - elif parsed["type"] == "tokens": - prompt_text = None + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if parsed["type"] == "tokens": prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - elif parsed["type"] == "text": + + return token_inputs( + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + + if parsed["type"] == "text": prompt_text = parsed["content"]["prompt"] prompt_token_ids = self._tokenize_prompt( prompt_text, @@ -257,18 +257,22 @@ def _extract_prompt_components( ) multi_modal_data = parsed["content"].get("multi_modal_data") mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - else: - assert_never(parsed) - return (prompt_text, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + + assert_never(parsed) - async def _extract_prompt_components_async( + async def _prompt_to_llm_inputs_async( self, prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: + ) -> DecoderOnlyInputs: """Async version of :meth:`_extract_prompt_components`.""" parsed = parse_singleton_prompt(prompt) @@ -279,14 +283,24 @@ async def _extract_prompt_components_async( request_id=request_id, lora_request=lora_request, ) - multi_modal_data = None - mm_processor_kwargs = None - elif parsed["type"] == "tokens": - prompt_text = None + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if parsed["type"] == "tokens": prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - elif parsed["type"] == "text": + + return token_inputs( + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + + if parsed["type"] == "text": prompt_text = parsed["content"]["prompt"] prompt_token_ids = await self._tokenize_prompt_async( prompt_text, @@ -295,43 +309,49 @@ async def _extract_prompt_components_async( ) multi_modal_data = parsed["content"].get("multi_modal_data") mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - else: - assert_never(parsed) - return (prompt_text, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + + assert_never(parsed) def _build_enc_dec_llm_inputs( self, - encoder_comps: PromptComponents, - decoder_comps: DecoderPromptComponents, - mm_processor_kwargs: Dict[str, Any], + encoder_inputs: TokenInputs, + decoder_inputs: Optional[TokenInputs], ) -> EncoderDecoderInputs: - encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps - decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps - - # Reminder: Please update docs/source/serving/compatibility_matrix.rst - # If the feature combo become valid - if decoder_mm_data is not None: 
- raise ValueError( - "Multi-modality decoder inputs of encoder-decoder models are " - "not supported yet") - - # For Multi-Modal models (e.g., mllama), the text input can be - # <|image|><|begin_of_text|>hello world. And we should not add - # another <|begin_of_text|> to the beginning. - decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation( - decoder_prompt_ids, - force_bos=(encoder_mm_data is None and decoder_mm_data is None))) + if encoder_inputs["type"] == "token": + pass + else: + assert_never(encoder_inputs) + + if decoder_inputs is None: + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + None, + force_bos="multi_modal_data" not in encoder_inputs, + ) + decoder_inputs = token_inputs(dec_token_ids) + elif decoder_inputs["type"] == "token": + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + decoder_inputs["prompt_token_ids"], + force_bos=("multi_modal_data" not in encoder_inputs + and "multi_modal_data" not in decoder_inputs), + ) + decoder_inputs["prompt_token_ids"] = dec_token_ids + + if "multi_modal_data" in decoder_inputs: + raise ValueError("Multi-modal decoder inputs of encoder-" + "decoder models are not supported yet") + else: + assert_never(encoder_inputs) return EncoderDecoderInputs( - prompt_token_ids=decoder_prompt_ids, - prompt=decoder_prompt, - multi_modal_data=decoder_mm_data, - mm_processor_kwargs=mm_processor_kwargs, - encoder_prompt_token_ids=encoder_prompt_ids, - encoder_prompt=encoder_prompt, - encoder_multi_modal_data=encoder_mm_data, + encoder=encoder_inputs, + decoder=decoder_inputs, ) def _process_encoder_decoder_prompt( @@ -341,8 +361,7 @@ def _process_encoder_decoder_prompt( ) -> EncoderDecoderInputs: ''' For encoder/decoder models only: - Process an input prompt into an - :class:`EncoderDecoderInputs` instance. + Process an input prompt into an :class:`EncoderDecoderInputs` instance. There are two types of input prompts: singleton prompts which carry only the @@ -361,7 +380,7 @@ def _process_encoder_decoder_prompt( have any possible singleton type; thus this method relies on helper functions to obtain token ids for the sub-prompts. 
- + Arguments: * prompt: an input prompt @@ -372,40 +391,31 @@ def _process_encoder_decoder_prompt( * :class:`EncoderDecoderInputs` instance ''' - encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents + encoder_inputs: TokenInputs + decoder_inputs: Optional[TokenInputs] if is_explicit_encoder_decoder_prompt(prompt): - encoder_comps = self._extract_prompt_components( + encoder_inputs = self._prompt_to_llm_inputs( prompt["encoder_prompt"], request_id=request_id, ) if (decoder_input := prompt["decoder_prompt"]) is None: - decoder_comps = None, None, None, None + decoder_inputs = None else: - decoder_comps = self._extract_prompt_components( + decoder_inputs = self._prompt_to_llm_inputs( decoder_input, request_id=request_id, ) - # Handle this carefully in case it was directly initialized by user - mm_processor_kwargs = prompt.get("mm_processor_kwargs", {}) else: - encoder_comps = self._extract_prompt_components( + encoder_inputs = self._prompt_to_llm_inputs( prompt, request_id=request_id, ) - # If there are no decoder components, we assume the - # mm_processor_kwargs are in the encoder prompt - mm_processor_kwargs = encoder_comps[-1] if encoder_comps[ - -1] is not None else {} - decoder_comps = None, None, None, None - - return self._build_enc_dec_llm_inputs( - encoder_comps, - decoder_comps, - mm_processor_kwargs, - ) + + decoder_inputs = None + + return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) async def _process_encoder_decoder_prompt_async( self, @@ -413,59 +423,50 @@ async def _process_encoder_decoder_prompt_async( request_id: str, ) -> EncoderDecoderInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" - encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents + encoder_inputs: TokenInputs + decoder_inputs: Optional[TokenInputs] if is_explicit_encoder_decoder_prompt(prompt): - encoder_task = self._extract_prompt_components_async( + encoder_task = self._prompt_to_llm_inputs_async( prompt["encoder_prompt"], request_id=request_id, ) if (decoder_input := prompt["decoder_prompt"]) is None: - encoder_comps = await encoder_task - decoder_comps = None, None, None, None + encoder_inputs = await encoder_task + decoder_inputs = None else: - decoder_task = self._extract_prompt_components_async( + decoder_task = self._prompt_to_llm_inputs_async( decoder_input, request_id=request_id, ) - encoder_comps, decoder_comps = await asyncio.gather( + encoder_inputs, decoder_inputs = await asyncio.gather( encoder_task, decoder_task) - mm_processor_kwargs = prompt["mm_processor_kwargs"] else: - encoder_comps = await self._extract_prompt_components_async( + encoder_inputs = await self._prompt_to_llm_inputs_async( prompt, request_id=request_id, ) - # If there are no decoder components, we assume the - # mm_processor_kwargs are in the encoder prompt - mm_processor_kwargs = encoder_comps[-1] if encoder_comps[ - -1] is not None else {} - decoder_comps = None, None, None, None - - return self._build_enc_dec_llm_inputs( - encoder_comps, - decoder_comps, - mm_processor_kwargs, - ) + + decoder_inputs = None + + return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) def _build_decoder_only_llm_inputs( self, - prompt_comps: PromptComponents, + prompt_inputs: DecoderOnlyInputs, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> DecoderOnlyInputs: - (prompt, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) = prompt_comps - - prompt_token_ids = self._apply_prompt_adapter( - prompt_token_ids, 
prompt_adapter_request=prompt_adapter_request) + if prompt_inputs["type"] == "token": + prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( + prompt_inputs["prompt_token_ids"], + prompt_adapter_request=prompt_adapter_request, + ) + else: + assert_never(prompt_inputs) - return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids, - prompt=prompt, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs) + return prompt_inputs def _process_decoder_only_prompt( self, @@ -490,7 +491,7 @@ def _process_decoder_only_prompt( * :class:`DecoderOnlyInputs` instance ''' - prompt_comps = self._extract_prompt_components( + prompt_comps = self._prompt_to_llm_inputs( prompt, request_id=request_id, lora_request=lora_request, @@ -509,7 +510,7 @@ async def _process_decoder_only_prompt_async( prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> DecoderOnlyInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" - prompt_comps = await self._extract_prompt_components_async( + prompt_comps = await self._prompt_to_llm_inputs_async( prompt, request_id=request_id, lora_request=lora_request, @@ -526,7 +527,7 @@ def preprocess( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: + ) -> ProcessorInputs: """Preprocess the input prompt.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of @@ -554,7 +555,7 @@ async def preprocess_async( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: + ) -> ProcessorInputs: """Async version of :meth:`preprocess`.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 4cebc91ce715c..41af8456f53a5 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -12,7 +12,7 @@ from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, resolve_mm_processor_kwargs) -from .data import DecoderOnlyInputs +from .data import ProcessorInputs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -100,7 +100,7 @@ def __getitem__(self, key: str) -> int: raise KeyError(msg) from exc -InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs] +InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs] """Preprocess the inputs to the model.""" @@ -248,8 +248,8 @@ def dummy_data_for_profiling( def _default_input_processor( self, ctx: InputContext, - inputs: DecoderOnlyInputs, - ) -> DecoderOnlyInputs: + inputs: ProcessorInputs, + ) -> ProcessorInputs: """The default input processor is a no-op.""" return inputs @@ -282,7 +282,7 @@ def _get_model_input_processor(self, model_cls: Type[nn.Module]): .get(model_cls, self._default_input_processor) def process_input(self, model_config: "ModelConfig", - inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: + inputs: ProcessorInputs) -> ProcessorInputs: """ Apply an input processor to an instance of model inputs. 
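
With `InputProcessor` retyped above to consume and produce `ProcessorInputs`, per-model input processors now receive the same typed dicts the preprocessor builds instead of unpacking positional tuples. The snippet below is a minimal sketch written against that signature, not part of the patch: the function name and the prepended token id are made up for illustration, registration of the processor on a model class is omitted, and the import paths assume this revision of the tree.

from vllm.inputs import InputContext, token_inputs
from vllm.inputs.data import ProcessorInputs


def example_input_processor(ctx: InputContext,
                            inputs: ProcessorInputs) -> ProcessorInputs:
    # Token inputs are plain typed dicts now, so a processor branches on
    # the "type" key instead of destructuring a tuple of components.
    if inputs.get("type") != "token":
        return inputs

    # Illustrative tweak only: prepend a hypothetical special token id.
    return token_inputs(
        prompt_token_ids=[1] + list(inputs["prompt_token_ids"]),
        prompt=inputs.get("prompt"),
        multi_modal_data=inputs.get("multi_modal_data"),
        mm_processor_kwargs=inputs.get("mm_processor_kwargs"),
    )

Because the processor returns the same dict shape it received, the no-op default in the registry and model-specific processors such as the mllama one in the next file can be swapped freely.
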
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 378231f14455a..39034b4e9d664 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -36,8 +36,7 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, - EncoderDecoderInputs, InputContext) +from vllm.inputs import INPUT_REGISTRY, EncoderDecoderInputs, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -52,6 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SequenceData +from vllm.utils import is_list_of from .clip import CLIPMLP from .interfaces import SupportsMultiModal @@ -87,34 +87,37 @@ def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int: def input_processor_for_mllama(ctx: InputContext, - inputs: Union[DecoderOnlyInputs, - EncoderDecoderInputs]): - # move encoder_prompt to prompt - if inputs.get("prompt") is None: - inputs["prompt"] = inputs["encoder_prompt"] - inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"] - - # process multi-modal data - multi_modal_data = inputs.get("encoder_multi_modal_data") - - if multi_modal_data is None or "image" not in multi_modal_data \ - or multi_modal_data["image"] is None: + inputs: EncoderDecoderInputs): + enc_inputs = inputs["encoder"] + dec_inputs = inputs["decoder"] + + # move encoder prompt to decoder + if dec_inputs.get("prompt") is None: + dec_inputs["prompt"] = enc_inputs["prompt"] + dec_inputs["prompt_token_ids"] = enc_inputs["prompt_token_ids"] + + multi_modal_data = enc_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: # text-only - inputs["encoder_prompt"] = "" - inputs["encoder_prompt_token_ids"] = [] - inputs["encoder_multi_modal_data"] = {} + enc_inputs["prompt"] = "" + enc_inputs["prompt_token_ids"] = [] + enc_inputs["multi_modal_data"] = {} return inputs - if isinstance(multi_modal_data['image'], Image.Image): - multi_modal_data['image'] = [multi_modal_data['image']] + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_data = [image_data] + + assert is_list_of(image_data, Image.Image) + # Since only the last group of consecutive images # are attended by the decoded tokens, we only need to # get the number of tiles for those images. num_decode_images = _get_num_image_in_last_group( - inputs["prompt_token_ids"]) + dec_inputs["prompt_token_ids"]) hf_config = ctx.model_config.hf_config num_tiles = 0 - for image in multi_modal_data["image"][::-1]: + for image in image_data[::-1]: width, height = image.size tile_size = hf_config.vision_config.image_size canvas_height, canvas_width = get_optimal_tiled_canvas( @@ -129,7 +132,6 @@ def input_processor_for_mllama(ctx: InputContext, num_decode_images -= 1 if num_decode_images == 0: break - # Set encoder prompt length based on the number of tiles. # This tells the block manager to allocate correct number # of slots for encoder tokens. 
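
Only the last group of consecutive images contributes encoder tokens, and the loop above turns each of those images into a tile count; the per-tile token count is derived in the next hunk. The snippet below is a rough, self-contained approximation of that accounting rather than the patch's code: the 560-pixel tile size and 4-tile cap are assumed config values, and the canvas search is a simplification of what get_optimal_tiled_canvas does.

import math


def count_tiles(width: int, height: int, tile_size: int = 560,
                max_tiles: int = 4) -> int:
    # Cover the image with whole tiles, then shrink the canvas until it
    # fits the tile budget. The real helper also optimizes aspect ratio.
    tiles_w = max(1, math.ceil(width / tile_size))
    tiles_h = max(1, math.ceil(height / tile_size))
    while tiles_w * tiles_h > max_tiles:
        if tiles_w >= tiles_h:
            tiles_w -= 1
        else:
            tiles_h -= 1
    return tiles_w * tiles_h


# Each tile contributes (image_size // 14) ** 2 + 1 encoder tokens,
# matching the token_per_chunk computation in the next hunk.
token_per_chunk = (560 // 14) ** 2 + 1
print(count_tiles(1280, 720) * token_per_chunk)
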
@@ -137,8 +139,9 @@ def input_processor_for_mllama(ctx: InputContext, "chunk size should be multiple of 14" token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 num_tokens = num_tiles * token_per_chunk - inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens - inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens + + enc_inputs["prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens + enc_inputs["prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens return inputs diff --git a/vllm/sequence.py b/vllm/sequence.py index 93f58f00ef77b..cc7e8f8fef1fd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -8,12 +8,12 @@ from functools import cached_property, reduce from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional from typing import Sequence as GenericSequence -from typing import Set, Tuple, Union, cast +from typing import Set, Tuple, Union import msgspec import torch +from typing_extensions import assert_never -from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -378,15 +378,10 @@ def __repr__(self) -> str: class Sequence: """Stores the data, status, and block information of a sequence. - - The sequence is constructed from the :code:`SingletonInputs` instance - passed in through the :code:`inputs` constructor argument. - - For encoder/decoder models, SingletonInputs encapsulates both a - decoder and encoder prompt, creating an ambiguity about which - prompt to construct the sequence from. The `from_decoder_prompt` - constructor argument signals whether to construct the Sequence - from the SingletonInputs decoder prompt, or encoder prompt. + + The sequence is constructed from the :data:`DecoderOnlyInputs` + (for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder) + instance passed in through the :code:`inputs` constructor argument. Args: seq_id: The ID of the sequence. @@ -396,10 +391,6 @@ class Sequence: eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. lora_request: LoRA request. prompt_adapter_request: Prompt Adapter request. - from_decoder_prompt: Construct Sequence from SingletonInputs decoder - prompt (True) or encoder prompt (False.) Must be - True for decoder-only model. - """ def __init__( @@ -410,7 +401,6 @@ def __init__( eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - from_decoder_prompt: bool = True, ) -> None: self.seq_id = seq_id self.inputs = inputs @@ -418,33 +408,6 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - self.from_decoder_prompt = from_decoder_prompt - - # For decoder-only models, a Sequence is constructed - # from an DecoderOnlyInputs instance (the `inputs` arg.) - # - # For encoder/decoder models the same `inputs` - # instance could be utilized to construct either an - # encoder sequence or a decoder sequence, because - # `DecoderOnlyInputs` has both decoder- and encoder-oriented - # member variables (i.e. it encapsulates both an encoder - # and a decoder prompt.) The decision of which type of sequence - # to generate is determined by the `from_decoder_prompt` argument. 
- # - # When constructing a encoder sequence - # (`from_decoder_prompt` False) it matters that - # the `DecoderOnlyInputs` instance stored in `inputs` is valid - # in the sense that its encoder-related member variables are - # populated; below, an exception is raised if this is - # not the case. - # - # When constructing a decoder sequence (`from_decoder_prompt` True) - # it does not matter whether `inputs` has its encoder-related - # member variables populated. - if not (from_decoder_prompt or is_encoder_decoder_inputs(inputs)): - raise ValueError("Cannot extract encoder input prompt from " - f"invalid input {inputs}; did you forget the " - "encoder input prompt fields?") self.data = SequenceData.from_seqs(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -469,41 +432,48 @@ def n_blocks(self) -> int: @cached_property def prompt(self) -> Optional[str]: - # Select decoder or encoder input prompt str, as appropriate - prompt_key: str = ("prompt" - if self.from_decoder_prompt else "encoder_prompt") + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("prompt") - return cast(Optional[str], self.inputs.get(prompt_key)) + assert_never(inputs) @cached_property def prompt_token_ids(self) -> List[int]: - # Select decoder or encoder input prompt token ids, as appropriate - prompt_token_ids_key: str = ("prompt_token_ids" - if self.from_decoder_prompt else - "encoder_prompt_token_ids") + inputs = self.inputs - # Cache computed prompt token ids - return cast(List[int], self.inputs.get(prompt_token_ids_key)) + if inputs["type"] == "token": + return inputs.get("prompt_token_ids", []) - @property + assert_never(inputs) + + @cached_property + def prompt_embeds(self) -> Optional[torch.Tensor]: + inputs = self.inputs + + if inputs["type"] == "token": + return None + + assert_never(inputs) + + @cached_property def multi_modal_data(self) -> "MultiModalDataDict": inputs = self.inputs - if (inputs.get("multi_modal_data") - and inputs.get("encoder_multi_modal_data")): - raise ValueError( - "Multi-modal data in both encoder and decoder is not supported." 
- ) + if inputs["type"] == "token": + return inputs.get("multi_modal_data", {}) - return cast( - "MultiModalDataDict", - (inputs.get("multi_modal_data") - or inputs.get("encoder_multi_modal_data") or {}), - ) + assert_never(inputs) - @property + @cached_property def mm_processor_kwargs(self) -> Dict[str, Any]: - return self.inputs.get("mm_processor_kwargs") or {} + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("mm_processor_kwargs", {}) + + assert_never(inputs) @property def lora_int_id(self) -> int: From fa5ad179b1010c0ba257050025e125a234369e7a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 23 Oct 2024 05:11:57 +0000 Subject: [PATCH 61/88] Rename --- vllm/inputs/preprocess.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 9681c7dc548f6..73db916dfc2ed 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -11,7 +11,7 @@ from vllm.utils import print_warning_once from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, - PromptType, SingletonPrompt, TokenInputs, token_inputs) + PromptType, SingletonInputs, SingletonPrompt, token_inputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt logger = init_logger(__name__) @@ -204,7 +204,7 @@ def _prompt_to_llm_inputs( prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> DecoderOnlyInputs: + ) -> SingletonInputs: ''' Extract the components of any single encoder or decoder input prompt. @@ -272,7 +272,7 @@ async def _prompt_to_llm_inputs_async( prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> DecoderOnlyInputs: + ) -> SingletonInputs: """Async version of :meth:`_extract_prompt_components`.""" parsed = parse_singleton_prompt(prompt) @@ -321,8 +321,8 @@ async def _prompt_to_llm_inputs_async( def _build_enc_dec_llm_inputs( self, - encoder_inputs: TokenInputs, - decoder_inputs: Optional[TokenInputs], + encoder_inputs: SingletonInputs, + decoder_inputs: Optional[SingletonInputs], ) -> EncoderDecoderInputs: if encoder_inputs["type"] == "token": pass @@ -391,8 +391,8 @@ def _process_encoder_decoder_prompt( * :class:`EncoderDecoderInputs` instance ''' - encoder_inputs: TokenInputs - decoder_inputs: Optional[TokenInputs] + encoder_inputs: SingletonInputs + decoder_inputs: Optional[SingletonInputs] if is_explicit_encoder_decoder_prompt(prompt): encoder_inputs = self._prompt_to_llm_inputs( @@ -423,8 +423,8 @@ async def _process_encoder_decoder_prompt_async( request_id: str, ) -> EncoderDecoderInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" - encoder_inputs: TokenInputs - decoder_inputs: Optional[TokenInputs] + encoder_inputs: SingletonInputs + decoder_inputs: Optional[SingletonInputs] if is_explicit_encoder_decoder_prompt(prompt): encoder_task = self._prompt_to_llm_inputs_async( From 44fd058d17237abb4b3948cddd119ab4eb5f9564 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 23 Oct 2024 05:12:49 +0000 Subject: [PATCH 62/88] Fix type error --- vllm/inputs/registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 41af8456f53a5..1b531b2ab5188 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -2,7 +2,7 @@ from collections import UserDict from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, - Protocol, 
Tuple, Type) + Protocol, Tuple, Type, cast) from torch import nn from transformers import PretrainedConfig @@ -302,7 +302,7 @@ def process_input(self, model_config: "ModelConfig", # If it's empty, it'll fall back to the default kwarg values mm_processor_kwargs = resolve_mm_processor_kwargs( model_config.mm_processor_kwargs, - inputs.get("mm_processor_kwargs"), + cast(Dict[str, Any], inputs.get("mm_processor_kwargs")), processor, ) From d167df38b79f2332c5663ce8654411f2a4a0ba7e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 23 Oct 2024 05:15:25 +0000 Subject: [PATCH 63/88] Fix bad merge --- vllm/inputs/preprocess.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index cba1e48ebecf0..e81b22b448d1a 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -12,7 +12,8 @@ from vllm.utils import print_warning_once from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, - PromptType, SingletonInputs, SingletonPrompt, token_inputs) + PromptType, SingletonInputs, SingletonPrompt, embed_inputs, + token_inputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt logger = init_logger(__name__) From b73a345757b9fc24c9f47c9e8edda65031a9e862 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 23 Oct 2024 07:05:07 +0000 Subject: [PATCH 64/88] Fix test --- tests/engine/output_processor/test_stop_checker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/engine/output_processor/test_stop_checker.py b/tests/engine/output_processor/test_stop_checker.py index 0d84443c51f99..cc14e8cbf75df 100644 --- a/tests/engine/output_processor/test_stop_checker.py +++ b/tests/engine/output_processor/test_stop_checker.py @@ -4,6 +4,7 @@ from transformers import PreTrainedTokenizer from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.inputs import token_inputs from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob, Sequence, SequenceStatus @@ -15,7 +16,7 @@ def sequence_with_eos(text: str, eos_token: str, """ seq = Sequence( seq_id=0, - inputs={"prompt_token_ids": []}, + inputs=token_inputs([]), block_size=16, eos_token_id=eos_token_id, ) From fa968b59b5a951893ce55aad538f9a5b1169c1aa Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 23 Oct 2024 07:37:04 +0000 Subject: [PATCH 65/88] Fix llama-3.2 --- vllm/model_executor/models/mllama.py | 58 +++++++++++++++++++--------- 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 39034b4e9d664..a7e5f0d106fdb 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -36,7 +36,8 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, EncoderDecoderInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderInputs, InputContext, + TokenInputs, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -88,21 +89,32 @@ def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int: def input_processor_for_mllama(ctx: InputContext, inputs: EncoderDecoderInputs): - enc_inputs = inputs["encoder"] - dec_inputs = inputs["decoder"] + # 
Example inputs when initially passed to processor: + # { + # 'encoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 + # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # 'decoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000], + # }, + # } # move encoder prompt to decoder - if dec_inputs.get("prompt") is None: - dec_inputs["prompt"] = enc_inputs["prompt"] - dec_inputs["prompt_token_ids"] = enc_inputs["prompt_token_ids"] + inputs["decoder"] = TokenInputs(**inputs["encoder"]) + + dec_inputs = inputs["decoder"] - multi_modal_data = enc_inputs.get("multi_modal_data") + multi_modal_data = dec_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: # text-only - enc_inputs["prompt"] = "" - enc_inputs["prompt_token_ids"] = [] - enc_inputs["multi_modal_data"] = {} - return inputs + return EncoderDecoderInputs( + encoder=token_inputs([]), + decoder=dec_inputs, + ) image_data = multi_modal_data["image"] if isinstance(image_data, Image.Image): @@ -115,15 +127,18 @@ def input_processor_for_mllama(ctx: InputContext, # get the number of tiles for those images. num_decode_images = _get_num_image_in_last_group( dec_inputs["prompt_token_ids"]) + hf_config = ctx.model_config.hf_config + vision_config = hf_config.vision_config + num_tiles = 0 for image in image_data[::-1]: width, height = image.size - tile_size = hf_config.vision_config.image_size + tile_size = vision_config.image_size canvas_height, canvas_width = get_optimal_tiled_canvas( image_height=height, image_width=width, - max_image_tiles=hf_config.vision_config.max_num_tiles, + max_image_tiles=vision_config.max_num_tiles, tile_size=tile_size, ) num_tiles_height = canvas_height // tile_size @@ -132,18 +147,23 @@ def input_processor_for_mllama(ctx: InputContext, num_decode_images -= 1 if num_decode_images == 0: break + # Set encoder prompt length based on the number of tiles. # This tells the block manager to allocate correct number # of slots for encoder tokens. 
- assert hf_config.vision_config.image_size % 14 == 0, \ + assert vision_config.image_size % 14 == 0, \ "chunk size should be multiple of 14" - token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 + token_per_chunk = (vision_config.image_size // 14)**2 + 1 num_tokens = num_tiles * token_per_chunk - enc_inputs["prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens - enc_inputs["prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens - - return inputs + return EncoderDecoderInputs( + encoder=token_inputs( + prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens, + prompt=MLLAMA_IMAGE_TOKEN * num_tokens, + multi_modal_data=multi_modal_data, + ), + decoder=dec_inputs, + ) def get_max_mllama_image_tokens(ctx: InputContext) -> int: From 5ccc390f3de9642a3b6e89a3aac998cd958f3218 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 23 Oct 2024 08:11:42 +0000 Subject: [PATCH 66/88] Fix wrong variable --- vllm/model_executor/models/gpt_bigcode.py | 2 +- vllm/model_executor/models/olmo.py | 2 +- vllm/model_executor/models/opt.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 4945faafb3573..6e54849f8a3c7 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -230,7 +230,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, self.wte, + inputs_embeds = get_inputs_embeds(input_ids, self.wte, inputs_embeds, inputs_embeds_masks) position_embeds = self.wpe(position_ids) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 9b855b72fda69..74987d29cbb17 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -261,7 +261,7 @@ def forward( if get_pp_group().is_first_rank: # Get embeddings of input. 
# shape: (batch_size, seq_len, d_model) - hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds = get_inputs_embeds(input_ids, self.embed_tokens, inputs_embeds, inputs_embeds_masks) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 504e1ab982ca4..7902adb813449 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -255,7 +255,7 @@ def forward( inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = get_inputs_embeds(input_ids, + inputs_embeds = get_inputs_embeds(input_ids, self.get_input_embeddings, inputs_embeds, inputs_embeds_masks) From 2fe159c0cfc2234f15ab7a97ae017717d6642153 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 23 Oct 2024 08:15:15 +0000 Subject: [PATCH 67/88] Impl get_inputs_embeds --- vllm/model_executor/models/solar.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index b9298ed031144..0698befec8b62 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -52,7 +52,7 @@ from vllm.utils import is_hip from .interfaces import SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, +from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -319,12 +319,14 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, + self.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) + residual = None else: assert intermediate_tensors is not None @@ -462,9 +464,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return model_output def compute_logits(self, hidden_states: torch.Tensor, From 7986553831ac9ffb6ae1b8ab4c9802f45884665f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 23 Oct 2024 10:00:04 +0000 Subject: [PATCH 68/88] Fix tests --- tests/worker/test_model_runner.py | 8 ++++---- vllm/sequence.py | 2 +- vllm/worker/model_runner.py | 17 +++++++++-------- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index d4677d0caf281..d64b63091fdca 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -52,7 +52,7 @@ def test_prepare_prompt(batch_size, prompt_embeds_ratio): if random.random() < prompt_embeds_ratio: seq_data = SequenceData.from_seqs( range(seq_len), - prompt_embeds=torch.rand(seq_len, 10).tolist(), + prompt_embeds=torch.rand(seq_len, 10), ) input_embeds_len += seq_len else: @@ -185,7 +185,7 @@ 
def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio): if random.random() < prompt_embeds_ratio: seq_data = SequenceData.from_seqs( [], - prompt_embeds=torch.rand(context_len, 10).tolist(), + prompt_embeds=torch.rand(context_len, 10), ) input_embeds_len += context_len else: @@ -365,7 +365,7 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, if random.random() < prompt_embeds_ratio: seq_data = SequenceData.from_seqs( [], - prompt_embeds=torch.rand(seq_len, 10).tolist(), + prompt_embeds=torch.rand(seq_len, 10), ) input_embeds_len += seq_len else: @@ -388,7 +388,7 @@ def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, if random.random() < prompt_embeds_ratio: seq_data = SequenceData.from_seqs( [], - prompt_embeds=torch.rand(context_len, 10).tolist(), + prompt_embeds=torch.rand(context_len, 10), ), else: seq_data = SequenceData.from_seqs(range(context_len)) diff --git a/vllm/sequence.py b/vllm/sequence.py index 2d14f0db64df9..7186c2bc23413 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -426,7 +426,7 @@ def __init__( if self.prompt_token_ids: data = SequenceData.from_seqs(self.prompt_token_ids) else: - assert self.prompt_embeds is not None + assert isinstance(self.prompt_embeds, torch.Tensor) data = SequenceData.from_seqs([], prompt_embeds=self.prompt_embeds) self.data = data diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fabd404c23441..90e4646675ea1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -831,11 +831,20 @@ def build(self) -> ModelInputForGPU: seq_lens = [] query_lens = [] + input_embeds_lst = [] + input_embeds_masks_lst = [] max_decode_seq_len = 0 max_encoder_seq_len = 0 + for inter_data in self.inter_data_list: seq_lens.extend(inter_data.seq_lens) query_lens.extend(inter_data.query_lens) + + if inter_data.input_embeds is not None: + input_embeds_lst.append(inter_data.input_embeds) + if inter_data.input_embeds_mask is not None: + input_embeds_masks_lst.append(inter_data.input_embeds_mask) + if not inter_data.is_prompt: max_decode_seq_len = max(max_decode_seq_len, max(inter_data.seq_lens)) @@ -843,10 +852,6 @@ def build(self) -> ModelInputForGPU: max_encoder_seq_len = max(max_encoder_seq_len, inter_data.encoder_seq_len) - input_embeds_lst = [ - inter_data.input_embeds for inter_data in self.inter_data_list - if inter_data.input_embeds is not None - ] if input_embeds_lst: input_embeds = torch.cat(input_embeds_lst).to( device=self.runner.device, @@ -854,10 +859,6 @@ def build(self) -> ModelInputForGPU: else: input_embeds = None - input_embeds_masks_lst = [ - inter_data.input_embeds_mask for inter_data in self.inter_data_list - if inter_data.input_embeds_mask is not None - ] if input_embeds_masks_lst: input_embeds_masks = torch.cat(input_embeds_masks_lst).to( self.runner.device) From 4c072caf1745739772e43c4ddbe2cb315fe86c77 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 23 Oct 2024 10:03:48 +0000 Subject: [PATCH 69/88] format --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 90e4646675ea1..9a628e1746fdf 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -839,7 +839,7 @@ def build(self) -> ModelInputForGPU: for inter_data in self.inter_data_list: seq_lens.extend(inter_data.seq_lens) query_lens.extend(inter_data.query_lens) - + if inter_data.input_embeds is not None: input_embeds_lst.append(inter_data.input_embeds) if 
inter_data.input_embeds_mask is not None: From 3b22bbcb146b8174fa7b64db801a4324a23b017d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 23 Oct 2024 10:34:39 +0000 Subject: [PATCH 70/88] Don't use `prompt_embeds` in `get_len` --- vllm/sequence.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 7186c2bc23413..c524fbe5eb641 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -242,18 +242,10 @@ def cumulative_logprob(self) -> float: def prompt_token_ids(self) -> Tuple[int, ...]: return self._prompt_token_ids_tuple - @prompt_token_ids.setter - def prompt_token_ids(self, new_prompt_token_ids) -> None: - raise NotImplementedError - @property def prompt_embeds(self) -> Optional[torch.Tensor]: return self._prompt_embeds - @prompt_embeds.setter - def prompt_embeds(self, new_prompt_embeds: Optional[torch.Tensor]) -> None: - self._prompt_embeds = new_prompt_embeds - @property def prompt_token_ids_array(self) -> array: """Return the prompt token ids in array type. @@ -297,10 +289,7 @@ def append_token_id(self, token_id: int, logprob: float) -> None: self._cumulative_logprob += logprob def get_len(self) -> int: - if self._prompt_embeds is None: - return len(self._output_token_ids) + len(self._prompt_token_ids) - else: - return len(self._output_token_ids) + len(self._prompt_embeds) + return len(self._output_token_ids) + len(self._prompt_token_ids) def get_prompt_len(self) -> int: return len(self._prompt_token_ids) From 906ee1ea9e79f46bd652abcdbf54422a3b5a05c3 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 24 Oct 2024 02:29:50 +0000 Subject: [PATCH 71/88] Remove force_bos --- vllm/inputs/preprocess.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 73db916dfc2ed..59441ecfd46a1 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -122,7 +122,6 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: def _prepare_decoder_input_ids_for_generation( self, decoder_input_ids: Optional[List[int]], - force_bos: bool = True, ) -> List[int]: """ Prepares `decoder_input_ids` for generation with encoder-decoder models. 
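
The hunk above removes the `force_bos` flag from `_prepare_decoder_input_ids_for_generation`; together with the unconditional prepend restored in the later "Fix" commit, the decoder prompt always starts with `decoder_start_token_id`. Below is a condensed, standalone sketch of that end state, not the method itself: the token ids are made up, and the real method obtains the start token id and the default decoder prompt from helpers on the preprocessor.

from typing import List, Optional


def prepare_decoder_ids(decoder_input_ids: Optional[List[int]],
                        decoder_start_token_id: int,
                        default_prompt: List[int]) -> List[int]:
    # Fall back to the default decoder prompt when none is given, then
    # make sure the sequence begins with the decoder start token.
    if decoder_input_ids is None:
        decoder_input_ids = list(default_prompt)
    if (not decoder_input_ids
            or decoder_input_ids[0] != decoder_start_token_id):
        decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
    return decoder_input_ids


assert prepare_decoder_ids(None, 2, []) == [2]
assert prepare_decoder_ids([5, 7], 2, []) == [2, 5, 7]
assert prepare_decoder_ids([2, 5, 7], 2, []) == [2, 5, 7]
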
@@ -152,10 +151,6 @@ def _prepare_decoder_input_ids_for_generation( # use decoder_start_token_id as decoder_input_ids decoder_input_ids = self._get_default_enc_dec_decoder_prompt() - if force_bos and (len(decoder_input_ids) == 0 - or decoder_input_ids[0] != decoder_start_token_id): - decoder_input_ids = [decoder_start_token_id] + decoder_input_ids - return decoder_input_ids def _apply_prompt_adapter( @@ -330,17 +325,11 @@ def _build_enc_dec_llm_inputs( assert_never(encoder_inputs) if decoder_inputs is None: - dec_token_ids = self._prepare_decoder_input_ids_for_generation( - None, - force_bos="multi_modal_data" not in encoder_inputs, - ) + dec_token_ids = self._prepare_decoder_input_ids_for_generation(None) decoder_inputs = token_inputs(dec_token_ids) elif decoder_inputs["type"] == "token": dec_token_ids = self._prepare_decoder_input_ids_for_generation( - decoder_inputs["prompt_token_ids"], - force_bos=("multi_modal_data" not in encoder_inputs - and "multi_modal_data" not in decoder_inputs), - ) + decoder_inputs["prompt_token_ids"]) decoder_inputs["prompt_token_ids"] = dec_token_ids if "multi_modal_data" in decoder_inputs: From 005ad95003ff57af6f921ac33a3507780ac756d2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 24 Oct 2024 02:32:46 +0000 Subject: [PATCH 72/88] Add example output --- vllm/model_executor/models/mllama.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index a7e5f0d106fdb..845700bc49d4d 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -89,7 +89,7 @@ def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int: def input_processor_for_mllama(ctx: InputContext, inputs: EncoderDecoderInputs): - # Example inputs when initially passed to processor: + # Example input to processor: # { # 'encoder': { # 'type': 'token', @@ -156,6 +156,21 @@ def input_processor_for_mllama(ctx: InputContext, token_per_chunk = (vision_config.image_size // 14)**2 + 1 num_tokens = num_tiles * token_per_chunk + # Example output from processor: + # { + # 'encoder': { + # 'type': 'token', + # 'prompt_token_ids': [128256, 128256, ..., 128256], + # 'prompt': '<|image|><|image|>...<|image|>', + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # 'decoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 + # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # } return EncoderDecoderInputs( encoder=token_inputs( prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens, From a5f0c163ad810f5bbaae465956f68e16cbb46092 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 24 Oct 2024 02:47:12 +0000 Subject: [PATCH 73/88] format --- vllm/inputs/preprocess.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 59441ecfd46a1..4bdbcb88e653b 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -325,7 +325,8 @@ def _build_enc_dec_llm_inputs( assert_never(encoder_inputs) if decoder_inputs is None: - dec_token_ids = self._prepare_decoder_input_ids_for_generation(None) + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + None) decoder_inputs = token_inputs(dec_token_ids) elif decoder_inputs["type"] == "token": dec_token_ids = 
self._prepare_decoder_input_ids_for_generation( From 6ab44e4829278d17039651572a0708e3b5f21d97 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 24 Oct 2024 04:24:37 +0000 Subject: [PATCH 74/88] Fix --- vllm/inputs/preprocess.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 4bdbcb88e653b..c501b5490c91c 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -151,6 +151,10 @@ def _prepare_decoder_input_ids_for_generation( # use decoder_start_token_id as decoder_input_ids decoder_input_ids = self._get_default_enc_dec_decoder_prompt() + if (len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id): + decoder_input_ids = [decoder_start_token_id] + decoder_input_ids + return decoder_input_ids def _apply_prompt_adapter( From 760db0549b82b51e83cd3e4905ac99475c6863a4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 31 Oct 2024 01:31:54 +0000 Subject: [PATCH 75/88] Fix merge --- vllm/inputs/parse.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index b11a151c4a585..3438effe6d4c8 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -98,6 +98,10 @@ def parse_singleton_prompt( raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") +def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: + return isinstance(prompt, dict) and "prompt_token_ids" in prompt + + def is_explicit_encoder_decoder_prompt( prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: return isinstance(prompt, dict) and "encoder_prompt" in prompt @@ -105,4 +109,4 @@ def is_explicit_encoder_decoder_prompt( def is_encoder_decoder_inputs( inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]: - return "encoder" in inputs and "decoder" in inputs + return "encoder" in inputs and "decoder" in inputs \ No newline at end of file From acb8e6f160fb744c68fa46e18f24a76341453f49 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 31 Oct 2024 01:32:26 +0000 Subject: [PATCH 76/88] Update mllama processing --- vllm/model_executor/models/mllama.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index e26c5f5dfb681..4ecd96d3e9159 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -87,8 +87,10 @@ def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int: return num_images -def input_processor_for_mllama(ctx: InputContext, - inputs: EncoderDecoderInputs): +def input_processor_for_mllama( + ctx: InputContext, + inputs: EncoderDecoderInputs, +)-> EncoderDecoderInputs: # Example input to processor: # { # 'encoder': { @@ -104,9 +106,7 @@ def input_processor_for_mllama(ctx: InputContext, # } # move encoder prompt to decoder - inputs["decoder"] = TokenInputs(**inputs["encoder"]) - - dec_inputs = inputs["decoder"] + dec_inputs = TokenInputs(**inputs["encoder"]) multi_modal_data = dec_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: From 3bed51909b3bddff3e84a5c0c2f0931e51ef8d32 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 31 Oct 2024 01:32:42 +0000 Subject: [PATCH 77/88] Fix line --- vllm/inputs/parse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 3438effe6d4c8..09f1ff2cb42e9 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ 
-109,4 +109,4 @@ def is_explicit_encoder_decoder_prompt( def is_encoder_decoder_inputs( inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]: - return "encoder" in inputs and "decoder" in inputs \ No newline at end of file + return "encoder" in inputs and "decoder" in inputs From ea861e015ce5f365fb958f981238a43ef56bc984 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 31 Oct 2024 01:33:55 +0000 Subject: [PATCH 78/88] format --- vllm/model_executor/models/mllama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 4ecd96d3e9159..83899c3b200ea 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -90,7 +90,7 @@ def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int: def input_processor_for_mllama( ctx: InputContext, inputs: EncoderDecoderInputs, -)-> EncoderDecoderInputs: +) -> EncoderDecoderInputs: # Example input to processor: # { # 'encoder': { From f654421cd980452a51e53e95d0271181357f0d16 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 31 Oct 2024 01:36:56 +0000 Subject: [PATCH 79/88] Avoid repeated lookups --- vllm/inputs/preprocess.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index c501b5490c91c..453ff36bd6553 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -237,9 +237,11 @@ def _prompt_to_llm_inputs( ) if parsed["type"] == "tokens": - prompt_token_ids = parsed["content"]["prompt_token_ids"] - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") + content = parsed["content"] + + prompt_token_ids = content["prompt_token_ids"] + multi_modal_data = content.get("multi_modal_data") + mm_processor_kwargs = content.get("mm_processor_kwargs") return token_inputs( prompt_token_ids=prompt_token_ids, @@ -248,14 +250,16 @@ def _prompt_to_llm_inputs( ) if parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + content = parsed["content"] + + prompt_text = content["prompt"] prompt_token_ids = self._tokenize_prompt( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") + multi_modal_data = content.get("multi_modal_data") + mm_processor_kwargs = content.get("mm_processor_kwargs") return token_inputs( prompt=prompt_text, @@ -289,9 +293,11 @@ async def _prompt_to_llm_inputs_async( ) if parsed["type"] == "tokens": - prompt_token_ids = parsed["content"]["prompt_token_ids"] - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") + content = parsed["content"] + + prompt_token_ids = content["prompt_token_ids"] + multi_modal_data = content.get("multi_modal_data") + mm_processor_kwargs = content.get("mm_processor_kwargs") return token_inputs( prompt_token_ids=prompt_token_ids, @@ -300,14 +306,16 @@ async def _prompt_to_llm_inputs_async( ) if parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + content = parsed["content"] + + prompt_text = content["prompt"] prompt_token_ids = await self._tokenize_prompt_async( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = 
parsed["content"].get("mm_processor_kwargs") + multi_modal_data = content.get("multi_modal_data") + mm_processor_kwargs = content.get("mm_processor_kwargs") return token_inputs( prompt=prompt_text, From 594794e24344f1be5c0386f7bf1ca12d5f421a87 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 31 Oct 2024 01:37:47 +0000 Subject: [PATCH 80/88] Remove unused import --- vllm/engine/llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8d6c3bc725693..78c5baf840532 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -10,7 +10,7 @@ from typing import Set, Type, Union, cast, overload import torch -from typing_extensions import TypeIs, TypeVar +from typing_extensions import TypeVar import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, From 08ea824121160054da226760aeb27200a37c1fed Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 31 Oct 2024 01:39:02 +0000 Subject: [PATCH 81/88] Fix mypy --- vllm/inputs/preprocess.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 453ff36bd6553..a5c787a56b5a9 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -237,11 +237,11 @@ def _prompt_to_llm_inputs( ) if parsed["type"] == "tokens": - content = parsed["content"] + tokens_content = parsed["content"] - prompt_token_ids = content["prompt_token_ids"] - multi_modal_data = content.get("multi_modal_data") - mm_processor_kwargs = content.get("mm_processor_kwargs") + prompt_token_ids = tokens_content["prompt_token_ids"] + multi_modal_data = tokens_content.get("multi_modal_data") + mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") return token_inputs( prompt_token_ids=prompt_token_ids, @@ -250,16 +250,16 @@ def _prompt_to_llm_inputs( ) if parsed["type"] == "text": - content = parsed["content"] + text_content = parsed["content"] - prompt_text = content["prompt"] + prompt_text = text_content["prompt"] prompt_token_ids = self._tokenize_prompt( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = content.get("multi_modal_data") - mm_processor_kwargs = content.get("mm_processor_kwargs") + multi_modal_data = text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") return token_inputs( prompt=prompt_text, @@ -293,11 +293,11 @@ async def _prompt_to_llm_inputs_async( ) if parsed["type"] == "tokens": - content = parsed["content"] + tokens_content = parsed["content"] - prompt_token_ids = content["prompt_token_ids"] - multi_modal_data = content.get("multi_modal_data") - mm_processor_kwargs = content.get("mm_processor_kwargs") + prompt_token_ids = tokens_content["prompt_token_ids"] + multi_modal_data = tokens_content.get("multi_modal_data") + mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") return token_inputs( prompt_token_ids=prompt_token_ids, @@ -306,16 +306,16 @@ async def _prompt_to_llm_inputs_async( ) if parsed["type"] == "text": - content = parsed["content"] + text_content = parsed["content"] - prompt_text = content["prompt"] + prompt_text = text_content["prompt"] prompt_token_ids = await self._tokenize_prompt_async( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = content.get("multi_modal_data") - mm_processor_kwargs = content.get("mm_processor_kwargs") + multi_modal_data = 
text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") return token_inputs( prompt=prompt_text, From 283bc2ccdd28bc02cab567631b326d59ae6e7f30 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 31 Oct 2024 10:00:31 +0000 Subject: [PATCH 82/88] Fix merge --- vllm/engine/protocol.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 6a09361c56865..e0b59d94cfdc3 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -1,11 +1,12 @@ import asyncio from abc import ABC, abstractmethod -from typing import AsyncGenerator, List, Mapping, Optional, Union +from typing import AsyncGenerator, List, Mapping, Optional from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -60,7 +61,7 @@ def generate( async def beam_search( self, - prompt: Union[PromptType, List[int]], + prompt: PromptType, model_config: ModelConfig, request_id: str, params: BeamSearchParams, @@ -76,11 +77,19 @@ async def beam_search( tokenizer = await self.get_tokenizer() input_preprocessor = InputPreprocessor(model_config, tokenizer) - (prompt_text, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) = input_preprocessor._extract_prompt_components( - prompt, - request_id=request_id, - ) + if is_explicit_encoder_decoder_prompt(prompt): + raise NotImplementedError + else: + processed_inputs = input_preprocessor._prompt_to_llm_inputs( + prompt, + request_id=request_id, + ) + + prompt_token_ids = processed_inputs["prompt_token_ids"] + prompt_text = processed_inputs.get("prompt") + multi_modal_data = processed_inputs.get("multi_modal_data") + mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs") + tokenized_length = len(prompt_token_ids) sort_beams_key = create_sort_beams_key_function( From b45cdc974c22d2022112119e846d7c16e4bc9085 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 3 Nov 2024 03:20:21 +0000 Subject: [PATCH 83/88] Fix missing import Signed-off-by: DarkLight1337 --- vllm/inputs/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index d9e853020ee76..b19e419b9ca84 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -3,7 +3,7 @@ SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) -from .registry import InputContext, InputRegistry +from .registry import DummyData, InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() """ From 4d33b1ec4d396bdbe889692827e4120348dba71b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 3 Nov 2024 03:25:44 +0000 Subject: [PATCH 84/88] Improve error message Signed-off-by: DarkLight1337 --- vllm/model_executor/models/registry.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f50ceaccb1bbe..02f2215ceaabd 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -342,9 +342,13 @@ def 
register_model( def _raise_for_unsupported(self, architectures: List[str]): all_supported_archs = self.get_supported_archs() - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {all_supported_archs}") + msg = (f"Model architectures {architectures} are not supported for " + f"now. Supported architectures: {all_supported_archs}") + if any(arch in all_supported_archs for arch in architectures): + msg += ("\n(Please check the logs to see why the model " + "failed to be inspected.)") + + raise ValueError(msg) def _try_load_model_cls(self, model_arch: str) -> Optional[Type[nn.Module]]: From 0a549e54ac9524a08b07ccd4846f96bf2b2ffb25 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 3 Nov 2024 03:28:10 +0000 Subject: [PATCH 85/88] Add missing export Signed-off-by: DarkLight1337 --- vllm/inputs/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index b19e419b9ca84..68ac50a2c5a16 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -22,6 +22,7 @@ "ExplicitEncoderDecoderPrompt", "TokenInputs", "token_inputs", + "DecoderOnlyInputs", "EncoderDecoderInputs", "ProcessorInputs", "SingletonInputs", From f741a75af1ca13bba44727d43963891e6fa473df Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 3 Nov 2024 11:46:28 +0800 Subject: [PATCH 86/88] Improve error message. --- vllm/model_executor/models/registry.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 02f2215ceaabd..b3870244cb001 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -342,13 +342,16 @@ def register_model( def _raise_for_unsupported(self, architectures: List[str]): all_supported_archs = self.get_supported_archs() - msg = (f"Model architectures {architectures} are not supported for " - f"now. Supported architectures: {all_supported_archs}") if any(arch in all_supported_archs for arch in architectures): - msg += ("\n(Please check the logs to see why the model " - "failed to be inspected.)") - - raise ValueError(msg) + raise ValueError( + f"Model architectures {architectures} failed " + "to be inspected. Please check the logs for more details." + ) + + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {all_supported_archs}" + ) def _try_load_model_cls(self, model_arch: str) -> Optional[Type[nn.Module]]: From cd231fa780a14f99ed337a6f550cb2b269991a9e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 3 Nov 2024 11:47:35 +0800 Subject: [PATCH 87/88] Format --- vllm/model_executor/models/registry.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b3870244cb001..1fd20307d92db 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -345,13 +345,11 @@ def _raise_for_unsupported(self, architectures: List[str]): if any(arch in all_supported_archs for arch in architectures): raise ValueError( f"Model architectures {architectures} failed " - "to be inspected. Please check the logs for more details." - ) + "to be inspected. Please check the logs for more details.") raise ValueError( f"Model architectures {architectures} are not supported for now. 
" - f"Supported architectures: {all_supported_archs}" - ) + f"Supported architectures: {all_supported_archs}") def _try_load_model_cls(self, model_arch: str) -> Optional[Type[nn.Module]]: From c8fc1feada91c0624ca6d3c66c7f5a28e816fd7b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 6 Nov 2024 09:21:55 +0000 Subject: [PATCH 88/88] Update `get_inputs_embeds` to be compatible with `torch.compile` Signed-off-by: DarkLight1337 --- vllm/model_executor/models/utils.py | 2 +- vllm/sequence.py | 19 ++++++++++++++++--- vllm/worker/model_runner.py | 12 ++++++++++-- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 48c1cb0da6052..6013eae63b288 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -572,7 +572,7 @@ def get_inputs_embeds( ) -> torch.Tensor: """Get the input embeddings from either `input_ids` and `inputs_embeds`.""" if inputs_embeds is not None: - if inputs_embeds_masks is None or inputs_embeds_masks.all().item(): + if inputs_embeds_masks is None: hidden_states = inputs_embeds else: msg = "inputs_embeds should not be masked out for multimodal models" diff --git a/vllm/sequence.py b/vllm/sequence.py index d51a5ae26a99a..5cc2f15aa406f 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -9,7 +9,7 @@ from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, Mapping, Optional) from typing import Sequence as GenericSequence -from typing import Set, Tuple, Union +from typing import Set, Tuple, Union, overload import msgspec import torch @@ -1125,12 +1125,25 @@ class IntermediateTensors: tensors: Dict[str, torch.Tensor] - def __getitem__(self, key: Union[str, slice]): + @overload + def __getitem__(self, key: str) -> torch.Tensor: + ... + + @overload + def __getitem__(self, key: slice) -> "IntermediateTensors": + ... + + def __getitem__( + self, + key: Union[str, slice], + ) -> Union[torch.Tensor, "IntermediateTensors"]: if isinstance(key, str): return self.tensors[key] - elif isinstance(key, slice): + if isinstance(key, slice): return self.__class__({k: v[key] for k, v in self.tensors.items()}) + assert_never(key) + def __setitem__(self, key: str, value): self.tensors[key] = value diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 102d2f3fa2497..006089bf19e02 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1710,9 +1710,17 @@ def execute_model( multi_modal_kwargs, device=self.device), **seqlen_agnostic_kwargs) if self.model_supports_input_embeds: + input_embeds = model_input.input_embeds + + input_embeds_masks = model_input.input_embeds_masks + if (input_embeds_masks is not None + and input_embeds_masks.all().item()): + input_embeds_masks = None + model_params.update( - inputs_embeds=model_input.input_embeds, - inputs_embeds_masks=model_input.input_embeds_masks) + inputs_embeds=input_embeds, + inputs_embeds_masks=input_embeds_masks, + ) hidden_or_intermediate_states = model_executable(**model_params)