
Commit 887d7af

[Core] Gate prompt_embeds behind a feature flag (#17607)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent a928424 commit 887d7af
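
With this change, `prompt_embeds` inputs are rejected unless the engine is created with the new flag (`enable_prompt_embeds=True` in Python, `--enable-prompt-embeds` on the CLI). The sketch below is not part of the commit; it mirrors the new test in tests/engine/test_options.py and assumes a downloadable distilbert/distilgpt2 checkpoint plus the V0 engine, since `prompt_embeds` remains V0-only at this commit.

# Sketch only (not part of this commit). Mirrors
# tests/engine/test_options.py::test_enable_prompt_embeds.
import os

os.environ["VLLM_USE_V1"] = "0"  # prompt_embeds is V0-only at this commit

from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM

model_name = "distilbert/distilgpt2"

# Build prompt embeddings with the HF model's input embedding layer.
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

# Without enable_prompt_embeds=True, generate() now raises:
#   ValueError: You must set `--enable-prompt-embeds` to input `prompt_embeds`.
llm = LLM(model=model_name, enable_prompt_embeds=True, enforce_eager=True)
outputs = llm.generate({"prompt_embeds": prompt_embeds})
print(outputs[0].outputs[0].text)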

8 files changed, +84 -33 lines

tests/engine/test_options.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+from contextlib import nullcontext
+
+import pytest
+
+from vllm.entrypoints.llm import LLM
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
+def test_skip_tokenizer_initialization(model: str):
+    # This test checks if the flag skip_tokenizer_init skips the initialization
+    # of tokenizer and detokenizer. The generated output is expected to contain
+    # token ids.
+    llm = LLM(
+        model=model,
+        skip_tokenizer_init=True,
+        enforce_eager=True,
+    )
+    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
+
+    with pytest.raises(ValueError, match="cannot pass text prompts when"):
+        llm.generate("abc", sampling_params)
+
+    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
+                           sampling_params=sampling_params)
+    assert len(outputs) > 0
+    completions = outputs[0].outputs
+    assert len(completions) > 0
+    assert completions[0].text == ""
+    assert completions[0].token_ids
+
+
+@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
+def test_enable_prompt_embeds(hf_runner, model: str,
+                              enable_prompt_embeds: bool):
+    prompt = "abc"
+
+    with hf_runner(model) as hf_model:
+        token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids
+        token_ids = token_ids.to(hf_model.model.device)
+
+        embed_layer = hf_model.model.get_input_embeddings()
+        prompt_embeds = embed_layer(token_ids).squeeze(0)
+
+    ctx = (nullcontext() if enable_prompt_embeds else pytest.raises(
+        ValueError, match="set `--enable-prompt-embeds`"))
+
+    # This test checks that the engine accepts `prompt_embeds` inputs only
+    # when it is constructed with `enable_prompt_embeds=True`; otherwise it
+    # raises a ValueError pointing at `--enable-prompt-embeds`.
+    llm = LLM(
+        model=model,
+        enable_prompt_embeds=enable_prompt_embeds,
+        enforce_eager=True,
+    )
+
+    with ctx:
+        llm.generate({"prompt_embeds": prompt_embeds})

tests/engine/test_skip_tokenizer_init.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

tests/models/language/generation/test_common.py

Lines changed: 6 additions & 2 deletions
@@ -109,12 +109,15 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
         # in parts of the operators
         pytest.skip(f"Skipping '{model}' model test with AITER kernel.")

+    use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0"
+
     with hf_runner(model) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)

-        prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv(
-            "VLLM_USE_V1") == "0" else None
+        prompt_embeds: Optional[list[torch.Tensor]] = ([] if use_prompt_embeds
+                                                       else None)
+
         prompt_token_ids = []
         for prompt in example_prompts:
             token_ids = hf_model.tokenizer(prompt,
@@ -131,6 +134,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
             tokenizer_mode=model_info.tokenizer_mode,
             trust_remote_code=model_info.trust_remote_code,
             max_num_seqs=2,
+            enable_prompt_embeds=use_prompt_embeds,
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)

tests/worker/test_model_runner.py

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,7 @@ def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
+        enable_prompt_embeds=True,
     )

     seq_lens: list[int] = []
@@ -179,6 +180,7 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
+        enable_prompt_embeds=True,
     )

     context_lens: list[int] = []
@@ -359,6 +361,7 @@ def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=True,
+        enable_prompt_embeds=True,
     )

     # Add prefill requests.

vllm/config.py

Lines changed: 4 additions & 0 deletions
@@ -321,6 +321,10 @@ class ModelConfig:
     """Skip initialization of tokenizer and detokenizer. Expects valid
     `prompt_token_ids` and `None` for prompt from the input. The generated
     output will contain token ids."""
+    enable_prompt_embeds: bool = False
+    """If `True`, enables passing text embeddings as inputs via the
+    `prompt_embeds` key. Note that enabling this will double the time required
+    for graph compilation."""
     served_model_name: Optional[Union[str, list[str]]] = None
     """The model name(s) used in the API. If multiple names are provided, the
     server will respond to any of the provided names. The model name in the

vllm/engine/arg_utils.py

Lines changed: 4 additions & 0 deletions
@@ -234,6 +234,7 @@ class EngineArgs:
     hf_config_path: Optional[str] = ModelConfig.hf_config_path
     task: TaskOption = ModelConfig.task
     skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
+    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
     tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
     trust_remote_code: bool = ModelConfig.trust_remote_code
     allowed_local_media_path: str = ModelConfig.allowed_local_media_path
@@ -445,6 +446,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                  **model_kwargs["disable_cascade_attn"])
         model_group.add_argument("--skip-tokenizer-init",
                                  **model_kwargs["skip_tokenizer_init"])
+        model_group.add_argument("--enable-prompt-embeds",
+                                 **model_kwargs["enable_prompt_embeds"])
         model_group.add_argument("--served-model-name",
                                  **model_kwargs["served_model_name"])
         # This one is a special case because it is the
@@ -874,6 +877,7 @@ def create_model_config(self) -> ModelConfig:
             disable_sliding_window=self.disable_sliding_window,
             disable_cascade_attn=self.disable_cascade_attn,
             skip_tokenizer_init=self.skip_tokenizer_init,
+            enable_prompt_embeds=self.enable_prompt_embeds,
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             use_async_output_proc=not self.disable_async_output_proc,
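
A small sketch, not part of the diff, of how the new flag round-trips from the command line into `EngineArgs`; `create_model_config` above then forwards it to `ModelConfig.enable_prompt_embeds`.

# Sketch: parse the new flag the same way the CLI entrypoints would.
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(
    ["--model", "distilbert/distilgpt2", "--enable-prompt-embeds"])

engine_args = EngineArgs.from_cli_args(args)
assert engine_args.enable_prompt_embeds is True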

vllm/inputs/preprocess.py

Lines changed: 4 additions & 1 deletion
@@ -303,8 +303,11 @@ def _process_embeds(
         self,
         parsed_content: EmbedsPrompt,
     ) -> EmbedsInputs:
+        if not self.model_config.enable_prompt_embeds:
+            raise ValueError("You must set `--enable-prompt-embeds` to input "
+                             "`prompt_embeds`.")
         if envs.VLLM_USE_V1:
-            raise ValueError("prompt_embeds is only available in V0.")
+            raise ValueError("`prompt_embeds` is only available in V0.")

         prompt_embeds = parsed_content["prompt_embeds"]

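Because the new check runs before the existing V1 restriction, a caller who omits the flag always sees the new message first. A rough sketch of that failure path, assuming distilbert/distilgpt2 (hidden size 768) and dummy embeddings:

# Sketch: embeds prompts are rejected at preprocessing time when the flag is off.
import torch

from vllm import LLM

llm = LLM(model="distilbert/distilgpt2", enforce_eager=True)
try:
    llm.generate({"prompt_embeds": torch.zeros(3, 768)})
except ValueError as err:
    print(err)  # You must set `--enable-prompt-embeds` to input `prompt_embeds`.
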
vllm/worker/model_runner.py

Lines changed: 3 additions & 1 deletion
@@ -1565,7 +1565,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         # product.
         cudagraph_capture_sizes = self.vllm_config.compilation_config\
             .cudagraph_capture_sizes
-        cudagraph_inputs_embeds = (True, False)
+        cudagraph_inputs_embeds = ((
+            True, False) if self.model_config.enable_prompt_embeds else
+                                   (False, ))
         compilation_cases = itertools.product(
             cudagraph_capture_sizes,
             cudagraph_inputs_embeds,
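
This is where the "double the time required for graph compilation" note in `ModelConfig` comes from: the capture cases are the product of the capture sizes and the inputs-embeds variants, so enabling the flag captures every size twice. A standalone sketch with hypothetical capture sizes:

# Standalone sketch (hypothetical capture sizes): the number of CUDA graph
# captures doubles when prompt embeds are enabled.
import itertools

cudagraph_capture_sizes = [1, 2, 4, 8]  # hypothetical

for enable_prompt_embeds in (False, True):
    cudagraph_inputs_embeds = ((True, False)
                               if enable_prompt_embeds else (False, ))
    cases = list(
        itertools.product(cudagraph_capture_sizes, cudagraph_inputs_embeds))
    print(f"enable_prompt_embeds={enable_prompt_embeds}: {len(cases)} captures")
    # -> 4 captures when disabled, 8 when enabled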
