From 9b43544228180ad51e1e0b8220d006329525628b Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Sat, 3 May 2025 05:14:52 +0000
Subject: [PATCH 1/3] [Core] Gate `prompt_embeds` behind a feature flag

Signed-off-by: DarkLight1337
---
 tests/engine/test_options.py | 60 +++++++++++++++++++
 tests/engine/test_skip_tokenizer_init.py | 29 ---------
 .../models/language/generation/test_common.py | 8 ++-
 tests/worker/test_model_runner.py | 3 +
 vllm/config.py | 3 +
 vllm/engine/arg_utils.py | 2 +
 vllm/inputs/preprocess.py | 5 +-
 vllm/worker/model_runner.py | 4 +-
 8 files changed, 81 insertions(+), 33 deletions(-)
 create mode 100644 tests/engine/test_options.py
 delete mode 100644 tests/engine/test_skip_tokenizer_init.py

diff --git a/tests/engine/test_options.py b/tests/engine/test_options.py
new file mode 100644
index 000000000000..0cf4f69d56a8
--- /dev/null
+++ b/tests/engine/test_options.py
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+from contextlib import nullcontext
+
+import pytest
+
+from vllm.entrypoints.llm import LLM
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
+def test_skip_tokenizer_initialization(model: str):
+    # This test checks if the flag skip_tokenizer_init skips the initialization
+    # of tokenizer and detokenizer. The generated output is expected to contain
+    # token ids.
+    llm = LLM(
+        model=model,
+        skip_tokenizer_init=True,
+        enforce_eager=True,
+    )
+    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
+
+    with pytest.raises(ValueError, match="cannot pass text prompts when"):
+        llm.generate("abc", sampling_params)
+
+    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
+                           sampling_params=sampling_params)
+    assert len(outputs) > 0
+    completions = outputs[0].outputs
+    assert len(completions) > 0
+    assert completions[0].text == ""
+    assert completions[0].token_ids
+
+
+@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
+def test_enable_prompt_embeds(hf_runner, model: str,
+                              enable_prompt_embeds: bool):
+    prompt = "abc"
+
+    with hf_runner(model) as hf_model:
+        token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids
+        token_ids = token_ids.to(hf_model.model.device)
+
+        embed_layer = hf_model.model.get_input_embeddings()
+        prompt_embeds = embed_layer(token_ids).squeeze(0)
+
+    ctx = (nullcontext() if enable_prompt_embeds else pytest.raises(
+        ValueError, match="set `--enable-prompt-embeds`"))
+
+    # This test checks that prompt embeddings can only be passed as inputs
+    # when the engine is started with `enable_prompt_embeds=True`; otherwise
+    # a ValueError pointing to `--enable-prompt-embeds` is raised.
+    llm = LLM(
+        model=model,
+        enable_prompt_embeds=enable_prompt_embeds,
+        enforce_eager=True,
+    )
+
+    with ctx:
+        llm.generate({"prompt_embeds": prompt_embeds})
diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py
deleted file mode 100644
index 5e197f5ffe59..000000000000
--- a/tests/engine/test_skip_tokenizer_init.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-
-from vllm.entrypoints.llm import LLM
-from vllm.sampling_params import SamplingParams
-
-
-@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
-def test_skip_tokenizer_initialization(model: str):
-    # This test checks if the flag skip_tokenizer_init skips the initialization
-    # of tokenizer and detokenizer. The generated output is expected to contain
-    # token ids.
-    llm = LLM(
-        model=model,
-        skip_tokenizer_init=True,
-    )
-    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
-
-    with pytest.raises(ValueError, match="cannot pass text prompts when"):
-        llm.generate("abc", sampling_params)
-
-    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
-                           sampling_params=sampling_params)
-    assert len(outputs) > 0
-    completions = outputs[0].outputs
-    assert len(completions) > 0
-    assert completions[0].text == ""
-    assert completions[0].token_ids
diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py
index fcd3fa036cfd..c755593c9acb 100644
--- a/tests/models/language/generation/test_common.py
+++ b/tests/models/language/generation/test_common.py
@@ -109,12 +109,15 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
         # in parts of the operators
         pytest.skip(f"Skipping '{model}' model test with AITER kernel.")

+    use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0"
+
     with hf_runner(model) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)

-        prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv(
-            "VLLM_USE_V1") == "0" else None
+        prompt_embeds: Optional[list[torch.Tensor]] = ([] if use_prompt_embeds
+                                                       else None)
+
         prompt_token_ids = []
         for prompt in example_prompts:
             token_ids = hf_model.tokenizer(prompt,
@@ -131,6 +134,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
             tokenizer_mode=model_info.tokenizer_mode,
             trust_remote_code=model_info.trust_remote_code,
             max_num_seqs=2,
+            enable_prompt_embeds=use_prompt_embeds,
         ) as vllm_model:
             vllm_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, num_logprobs)
diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index a1bdea687a85..ae4b536524be 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -43,6 +43,7 @@ def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
+        enable_prompt_embeds=True,
     )

     seq_lens: list[int] = []
@@ -179,6 +180,7 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
+        enable_prompt_embeds=True,
     )

     context_lens: list[int] = []
@@ -359,6 +361,7 @@ def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=True,
+        enable_prompt_embeds=True,
     )

     # Add prefill requests.
diff --git a/vllm/config.py b/vllm/config.py
index 1ae8673f7775..b06a1401b4e0 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -321,6 +321,9 @@ class ModelConfig:
     """Skip initialization of tokenizer and detokenizer. Expects valid
     `prompt_token_ids` and `None` for prompt from the input. The generated
     output will contain token ids."""
+    enable_prompt_embeds: bool = False
+    """If `True`, enables passing text embeddings as inputs via the
+    `prompt_embeds` key."""
     served_model_name: Optional[Union[str, list[str]]] = None
     """The model name(s) used in the API. If multiple names are provided, the
     server will respond to any of the provided names. The model name in the
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index aefba620e189..1ea95a57469c 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -234,6 +234,7 @@ class EngineArgs:
     hf_config_path: Optional[str] = ModelConfig.hf_config_path
     task: TaskOption = ModelConfig.task
     skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
+    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
     tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
     trust_remote_code: bool = ModelConfig.trust_remote_code
     allowed_local_media_path: str = ModelConfig.allowed_local_media_path
@@ -874,6 +875,7 @@ def create_model_config(self) -> ModelConfig:
             disable_sliding_window=self.disable_sliding_window,
             disable_cascade_attn=self.disable_cascade_attn,
             skip_tokenizer_init=self.skip_tokenizer_init,
+            enable_prompt_embeds=self.enable_prompt_embeds,
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             use_async_output_proc=not self.disable_async_output_proc,
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index 97a2ce5c615e..53e0a477a12d 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -303,8 +303,11 @@ def _process_embeds(
         self,
         parsed_content: EmbedsPrompt,
     ) -> EmbedsInputs:
+        if not self.model_config.enable_prompt_embeds:
+            raise ValueError("You must set `--enable-prompt-embeds` to input "
+                             "`prompt_embeds`.")
         if envs.VLLM_USE_V1:
-            raise ValueError("prompt_embeds is only available in V0.")
+            raise ValueError("`prompt_embeds` is only available in V0.")

         prompt_embeds = parsed_content["prompt_embeds"]

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 85814e9af9e3..e22bbcc656ff 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1565,7 +1565,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         # product.
         cudagraph_capture_sizes = self.vllm_config.compilation_config\
             .cudagraph_capture_sizes
-        cudagraph_inputs_embeds = (True, False)
+        cudagraph_inputs_embeds = ((
+            True, False) if self.model_config.enable_prompt_embeds else
+                                   (False, ))
         compilation_cases = itertools.product(
             cudagraph_capture_sizes,
             cudagraph_inputs_embeds,

From a369bbf01a5f763d80b20deabdd3a87859dbd121 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Sat, 3 May 2025 05:19:26 +0000
Subject: [PATCH 2/3] Update doc

Signed-off-by: DarkLight1337
---
 vllm/config.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/config.py b/vllm/config.py
index b06a1401b4e0..91ef9dcdbd56 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -323,7 +323,8 @@ class ModelConfig:
     output will contain token ids."""
     enable_prompt_embeds: bool = False
     """If `True`, enables passing text embeddings as inputs via the
-    `prompt_embeds` key."""
+    `prompt_embeds` key. Note that enabling this will double the time required
+    for graph compilation."""
     served_model_name: Optional[Union[str, list[str]]] = None
     """The model name(s) used in the API. If multiple names are provided, the
     server will respond to any of the provided names. The model name in the
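Usage note (not part of the patches above): the sketch below shows how the new `enable_prompt_embeds` flag is meant to be combined with a `prompt_embeds` input, mirroring the added `test_enable_prompt_embeds`. It is a rough, untested sketch that assumes the V0 engine (per `_process_embeds`, `prompt_embeds` is still rejected under `VLLM_USE_V1`) and that the embedding tensor's dtype and device are compatible with the model vLLM loads.

```python
# Hypothetical usage sketch based on tests/engine/test_options.py;
# not part of the diffs in this series.
# Requires the V0 engine (e.g. VLLM_USE_V1=0), since V1 still rejects prompt_embeds.
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm.entrypoints.llm import LLM

model_name = "distilbert/distilgpt2"

# Build prompt embeddings with the HF input embedding layer, as the new test does.
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
token_ids = tokenizer("abc", return_tensors="pt").input_ids
prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

# Without enable_prompt_embeds=True, _process_embeds now raises:
#   ValueError: You must set `--enable-prompt-embeds` to input `prompt_embeds`.
llm = LLM(
    model=model_name,
    enable_prompt_embeds=True,
    enforce_eager=True,  # skips CUDA graph capture, so no extra compilation cost
)

outputs = llm.generate({"prompt_embeds": prompt_embeds})
print(outputs[0].outputs[0].text)
```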
From 30d3dadb4877cc4dd9ccd64cfcb1e92ee9db858a Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Sat, 3 May 2025 05:28:29 +0000
Subject: [PATCH 3/3] Update CLI

Signed-off-by: DarkLight1337
---
 vllm/engine/arg_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1ea95a57469c..f6f8fb69fb70 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -446,6 +446,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                  **model_kwargs["disable_cascade_attn"])
         model_group.add_argument("--skip-tokenizer-init",
                                  **model_kwargs["skip_tokenizer_init"])
+        model_group.add_argument("--enable-prompt-embeds",
+                                 **model_kwargs["enable_prompt_embeds"])
         model_group.add_argument("--served-model-name",
                                  **model_kwargs["served_model_name"])
         # This one is a special case because it is the