diff --git a/tests/engine/test_options.py b/tests/engine/test_options.py
new file mode 100644
index 000000000000..0cf4f69d56a8
--- /dev/null
+++ b/tests/engine/test_options.py
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+from contextlib import nullcontext
+
+import pytest
+
+from vllm.entrypoints.llm import LLM
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
+def test_skip_tokenizer_initialization(model: str):
+    # This test checks if the flag skip_tokenizer_init skips the initialization
+    # of tokenizer and detokenizer. The generated output is expected to contain
+    # token ids.
+    llm = LLM(
+        model=model,
+        skip_tokenizer_init=True,
+        enforce_eager=True,
+    )
+    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
+
+    with pytest.raises(ValueError, match="cannot pass text prompts when"):
+        llm.generate("abc", sampling_params)
+
+    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
+                           sampling_params=sampling_params)
+    assert len(outputs) > 0
+    completions = outputs[0].outputs
+    assert len(completions) > 0
+    assert completions[0].text == ""
+    assert completions[0].token_ids
+
+
+@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
+def test_enable_prompt_embeds(hf_runner, model: str,
+                              enable_prompt_embeds: bool):
+    prompt = "abc"
+
+    with hf_runner(model) as hf_model:
+        token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids
+        token_ids = token_ids.to(hf_model.model.device)
+
+        embed_layer = hf_model.model.get_input_embeddings()
+        prompt_embeds = embed_layer(token_ids).squeeze(0)
+
+    ctx = (nullcontext() if enable_prompt_embeds else pytest.raises(
+        ValueError, match="set `--enable-prompt-embeds`"))
+
+    # This test checks that `prompt_embeds` inputs are only accepted when the
+    # engine is created with enable_prompt_embeds=True; otherwise a ValueError
+    # pointing at `--enable-prompt-embeds` is raised.
+    llm = LLM(
+        model=model,
+        enable_prompt_embeds=enable_prompt_embeds,
+        enforce_eager=True,
+    )
+
+    with ctx:
+        llm.generate({"prompt_embeds": prompt_embeds})
diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py
deleted file mode 100644
index 5e197f5ffe59..000000000000
--- a/tests/engine/test_skip_tokenizer_init.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-
-from vllm.entrypoints.llm import LLM
-from vllm.sampling_params import SamplingParams
-
-
-@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
-def test_skip_tokenizer_initialization(model: str):
-    # This test checks if the flag skip_tokenizer_init skips the initialization
-    # of tokenizer and detokenizer. The generated output is expected to contain
-    # token ids.
-    llm = LLM(
-        model=model,
-        skip_tokenizer_init=True,
-    )
-    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
-
-    with pytest.raises(ValueError, match="cannot pass text prompts when"):
-        llm.generate("abc", sampling_params)
-
-    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
-                           sampling_params=sampling_params)
-    assert len(outputs) > 0
-    completions = outputs[0].outputs
-    assert len(completions) > 0
-    assert completions[0].text == ""
-    assert completions[0].token_ids
diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py
index fcd3fa036cfd..c755593c9acb 100644
--- a/tests/models/language/generation/test_common.py
+++ b/tests/models/language/generation/test_common.py
@@ -109,12 +109,15 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
             # in parts of the operators
             pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
 
+    use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0"
+
     with hf_runner(model) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)
 
-        prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv(
-            "VLLM_USE_V1") == "0" else None
+        prompt_embeds: Optional[list[torch.Tensor]] = ([] if use_prompt_embeds
+                                                        else None)
+
         prompt_token_ids = []
         for prompt in example_prompts:
             token_ids = hf_model.tokenizer(prompt,
@@ -131,6 +134,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
             tokenizer_mode=model_info.tokenizer_mode,
             trust_remote_code=model_info.trust_remote_code,
             max_num_seqs=2,
+            enable_prompt_embeds=use_prompt_embeds,
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index a1bdea687a85..ae4b536524be 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -43,6 +43,7 @@ def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
+        enable_prompt_embeds=True,
     )
 
     seq_lens: list[int] = []
@@ -179,6 +180,7 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
+        enable_prompt_embeds=True,
     )
 
     context_lens: list[int] = []
@@ -359,6 +361,7 @@ def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=True,
+        enable_prompt_embeds=True,
     )
 
     # Add prefill requests.
diff --git a/vllm/config.py b/vllm/config.py
index 1ae8673f7775..91ef9dcdbd56 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -321,6 +321,10 @@ class ModelConfig:
     """Skip initialization of tokenizer and detokenizer. Expects valid
     `prompt_token_ids` and `None` for prompt from the input. The generated
     output will contain token ids."""
+    enable_prompt_embeds: bool = False
+    """If `True`, enables passing text embeddings as inputs via the
+    `prompt_embeds` key. Note that enabling this will double the time required
+    for graph compilation."""
     served_model_name: Optional[Union[str, list[str]]] = None
     """The model name(s) used in the API. If multiple names are provided, the
     server will respond to any of the provided names. The model name in the
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index aefba620e189..f6f8fb69fb70 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -234,6 +234,7 @@ class EngineArgs:
     hf_config_path: Optional[str] = ModelConfig.hf_config_path
     task: TaskOption = ModelConfig.task
     skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
+    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
     tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
     trust_remote_code: bool = ModelConfig.trust_remote_code
     allowed_local_media_path: str = ModelConfig.allowed_local_media_path
@@ -445,6 +446,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                  **model_kwargs["disable_cascade_attn"])
         model_group.add_argument("--skip-tokenizer-init",
                                  **model_kwargs["skip_tokenizer_init"])
+        model_group.add_argument("--enable-prompt-embeds",
+                                 **model_kwargs["enable_prompt_embeds"])
         model_group.add_argument("--served-model-name",
                                  **model_kwargs["served_model_name"])
         # This one is a special case because it is the
@@ -874,6 +877,7 @@ def create_model_config(self) -> ModelConfig:
             disable_sliding_window=self.disable_sliding_window,
             disable_cascade_attn=self.disable_cascade_attn,
             skip_tokenizer_init=self.skip_tokenizer_init,
+            enable_prompt_embeds=self.enable_prompt_embeds,
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             use_async_output_proc=not self.disable_async_output_proc,
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index 97a2ce5c615e..53e0a477a12d 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -303,8 +303,11 @@ def _process_embeds(
         self,
         parsed_content: EmbedsPrompt,
     ) -> EmbedsInputs:
+        if not self.model_config.enable_prompt_embeds:
+            raise ValueError("You must set `--enable-prompt-embeds` to input "
+                             "`prompt_embeds`.")
         if envs.VLLM_USE_V1:
-            raise ValueError("prompt_embeds is only available in V0.")
+            raise ValueError("`prompt_embeds` is only available in V0.")
 
         prompt_embeds = parsed_content["prompt_embeds"]
 
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 85814e9af9e3..e22bbcc656ff 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1565,7 +1565,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
             # product.
             cudagraph_capture_sizes = self.vllm_config.compilation_config\
                 .cudagraph_capture_sizes
-            cudagraph_inputs_embeds = (True, False)
+            cudagraph_inputs_embeds = ((
+                True, False) if self.model_config.enable_prompt_embeds else
+                                       (False, ))
             compilation_cases = itertools.product(
                 cudagraph_capture_sizes,
                 cudagraph_inputs_embeds,
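
Usage sketch (illustrative, not part of the diff): with this change, `prompt_embeds` inputs become opt-in via `enable_prompt_embeds=True` on `LLM` (or `--enable-prompt-embeds` on the CLI, added in `EngineArgs` above), and they remain V0-only because `_process_embeds` still raises on V1. The snippet below mirrors the new `test_enable_prompt_embeds` test; the `distilbert/distilgpt2` checkpoint comes from the tests, while the prompt text and sampling settings are arbitrary assumptions.

# Offline example of the new opt-in flag (assumes the V0 engine, VLLM_USE_V1=0).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "distilbert/distilgpt2"

# Build prompt embeddings from the HF model's input embedding layer.
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.no_grad():
    prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

# Without enable_prompt_embeds=True, this request now fails with
# "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
llm = LLM(model=model_name, enable_prompt_embeds=True, enforce_eager=True)
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)

Note that `capture_model` only doubles the CUDA graph capture space (embeds on/off) when the flag is set, which is why the `ModelConfig` docstring warns about the extra graph compilation time.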