
Commit a29add2

support input embedding
Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent 068c3a0 commit a29add2

File tree

5 files changed: +600 −39 lines changed


.github/workflows/vllm_ascend_test.yaml

Lines changed: 3 additions & 1 deletion
@@ -127,7 +127,9 @@ jobs:
         pytest -sv tests/singlecard/test_scheduler.py
         # guided decoding doesn't work, fix it later
         # pytest -sv tests/singlecard/test_guided_decoding.py.py
-        pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py
+        pytest -sv tests/singlecard/test_prompt_embedding.py
+        pytest -sv tests/singlecard/test_ilama_lora.py
+        pytest -sv tests/singlecard/test_pyhccl.py
       else
         pytest -sv tests/multicard/test_ilama_lora_tp2.py
         # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py will raise error.
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizer)
from vllm import LLM


def init_tokenizer_and_llm(model_name: str):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
    embedding_layer = transformers_model.get_input_embeddings()
    llm = LLM(model=model_name, enable_prompt_embeds=True)
    return tokenizer, embedding_layer, llm


def get_prompt_embeds(chat: list[dict[str, str]],
                      tokenizer: PreTrainedTokenizer,
                      embedding_layer: torch.nn.Module):
    token_ids = tokenizer.apply_chat_template(chat,
                                              add_generation_prompt=True,
                                              return_tensors='pt')
    prompt_embeds = embedding_layer(token_ids).squeeze(0)
    return prompt_embeds


def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
                            embedding_layer: torch.nn.Module):
    chat = [{
        "role": "user",
        "content": "Please tell me about the capital of France."
    }]
    prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)

    outputs = llm.generate({
        "prompt_embeds": prompt_embeds,
    })

    print("\n[Single Inference Output]")
    print("-" * 30)
    for o in outputs:
        print(o.outputs[0].text)
    print("-" * 30)


def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
                           embedding_layer: torch.nn.Module):
    chats = [[{
        "role": "user",
        "content": "Please tell me about the capital of France."
    }],
             [{
                 "role": "user",
                 "content": "When is the day longest during the year?"
             }],
             [{
                 "role": "user",
                 "content": "Where is bigger, the moon or the sun?"
             }]]

    prompt_embeds_list = [
        get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
    ]

    outputs = llm.generate([{
        "prompt_embeds": embeds
    } for embeds in prompt_embeds_list])

    print("\n[Batch Inference Outputs]")
    print("-" * 30)
    for i, o in enumerate(outputs):
        print(f"Q{i+1}: {chats[i][0]['content']}")
        print(f"A{i+1}: {o.outputs[0].text}\n")
    print("-" * 30)


def main():
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer, embedding_layer, llm = init_tokenizer_and_llm(model_name)
    single_prompt_inference(llm, tokenizer, embedding_layer)
    batch_prompt_inference(llm, tokenizer, embedding_layer)


if __name__ == "__main__":
    main()
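
The example above feeds prompt embeddings to an in-process LLM. The tests in the next file exercise the same feature through the OpenAI-compatible server, where the embeddings travel as a base64-encoded, torch-serialized tensor in the request's extra_body. Below is a minimal client-side sketch of that flow; the server URL, API key, model name, the encode_prompt_embeds helper, and the 2048 hidden size are illustrative assumptions (not part of this commit), and a server started with --enable-prompt-embeds must already be running.

import base64
import io

import torch
from openai import OpenAI


def encode_prompt_embeds(prompt_embeds: torch.Tensor) -> str:
    # Serialize the [num_tokens, hidden_size] tensor with torch.save and
    # wrap it in base64 so it can travel as a JSON string.
    buffer = io.BytesIO()
    torch.save(prompt_embeds, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# Placeholder connection details for a locally running vLLM server
# launched with --enable-prompt-embeds.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# For illustration only: random embeddings with a placeholder hidden size.
# In practice these would come from get_prompt_embeds(...) above.
prompt_embeds = torch.randn(5, 2048)

completion = client.completions.create(
    model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder served model name
    prompt="",  # empty prompt; the embeddings carry the actual input
    max_tokens=5,
    temperature=0.0,
    extra_body={"prompt_embeds": encode_prompt_embeds(prompt_embeds)},
)
print(completion.choices[0].text)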
Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/entrypoints/openai/test_completion_with_prompt_embeds.py
#
import base64
import io
import os

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
import torch
from modelscope import snapshot_download  # type: ignore
from openai import BadRequestError
from transformers import AutoConfig
from vllm.engine.arg_utils import EngineArgs

from tests.utils import RemoteOpenAIServer

if not hasattr(EngineArgs, "enable_prompt_embeds"):
    pytest.skip("Not supported vllm version", allow_module_level=True)

# any model with a chat template should work here
MODEL_NAME = snapshot_download("LLM-Research/Llama-3.2-1B-Instruct")

CONFIG = AutoConfig.from_pretrained(MODEL_NAME)


@pytest.fixture(scope="module")
def default_server_args() -> list[str]:
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
        # Prompt Embeds server args
        "--enable-prompt-embeds",
        "--no-enable-chunked-prefill",
    ]


@pytest.fixture(scope="module",
                params=["", "--disable-frontend-multiprocessing"])
def server_with_prompt_embeds(default_server_args, request):
    if request.param:
        default_server_args.append(request.param)

    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client_with_prompt_embeds(server_with_prompt_embeds):
    async with server_with_prompt_embeds.get_async_client() as async_client:
        yield async_client


def create_dummy_embeds(num_tokens: int = 5) -> str:
    """Create dummy embeddings and return them as base64 encoded string."""
    dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size)
    buffer = io.BytesIO()
    torch.save(dummy_embeds, buffer)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skipif(
    os.getenv("VLLM_USE_V1") == "1",
    reason="Enable embedding input will fallback to v0, skip it")
async def test_completions_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    # Test case: Single prompt embeds input
    encoded_embeds = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    assert len(completion.choices[0].text) >= 1
    assert completion.choices[0].prompt_logprobs is None

    # Test case: batch completion with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    assert len(completion.choices) == 2
    assert len(completion.choices[0].text) >= 1
    assert len(completion.choices[1].text) >= 1

    # Test case: streaming with prompt_embeds
    encoded_embeds = create_dummy_embeds()
    single_completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    single_output = single_completion.choices[0].text

    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": encoded_embeds})
    chunks = []
    finish_reason_count = 0
    async for chunk in stream:
        chunks.append(chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert "".join(chunks) == single_output

    # Test case: batch streaming with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    chunks_stream_embeds: list[list[str]] = [[], []]
    finish_reason_count = 0
    async for chunk in stream:
        chunks_stream_embeds[chunk.choices[0].index].append(
            chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    assert finish_reason_count == 2
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert len(chunks_stream_embeds[0]) > 0
    assert len(chunks_stream_embeds[1]) > 0

    # Test case: mixed text and prompt_embeds
    encoded_embeds = create_dummy_embeds()
    completion_mixed = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    assert len(completion_mixed.choices) == 2
    completion_text_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
    )
    completion_embeds_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    # Embeddings responses should be handled first
    assert completion_mixed.choices[0].text == completion_embeds_only.choices[
        0].text
    assert completion_mixed.choices[1].text == completion_text_only.choices[
        0].text


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skipif(
    os.getenv("VLLM_USE_V1") == "1",
    reason="Enable embedding input will fallback to v0, skip it")
async def test_completions_errors_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    # Test error case: invalid prompt_embeds
    with pytest.raises(BadRequestError):
        await client_with_prompt_embeds.completions.create(
            prompt="",
            model=model_name,
            max_tokens=5,
            temperature=0.0,
            extra_body={"prompt_embeds": "invalid_base64"})


@pytest.mark.asyncio
@pytest.mark.parametrize("logprobs_arg", [1, 0])
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skipif(
    os.getenv("VLLM_USE_V1") == "1",
    reason="Enable embedding input will fallback to v0, skip it")
async def test_completions_with_logprobs_and_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int,
        model_name: str):
    # Test case: Logprobs using prompt_embeds
    encoded_embeds = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": encoded_embeds})

    logprobs = completion.choices[0].logprobs
    assert logprobs is not None
    assert len(logprobs.text_offset) == 5
    assert len(logprobs.token_logprobs) == 5
    assert len(logprobs.top_logprobs) == 5
    for top_logprobs in logprobs.top_logprobs[1:]:
        assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1
    assert len(logprobs.tokens) == 5

    # Test case: Log probs with batch completion and prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})

    assert len(completion.choices) == 2
    for choice in completion.choices:
        logprobs = choice.logprobs
        assert logprobs is not None
        assert len(logprobs.text_offset) == 5
        assert len(logprobs.token_logprobs) == 5
        assert len(logprobs.top_logprobs) == 5
        for top_logprobs in logprobs.top_logprobs[1:]:
            assert max(logprobs_arg,
                       1) <= len(top_logprobs) <= logprobs_arg + 1
        assert len(logprobs.tokens) == 5
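
For reference, the base64 payload produced by create_dummy_embeds above is simply a torch-serialized tensor; the round-trip sketch below shows how such a payload can be reconstructed into a tensor. The decode_prompt_embeds helper and the 2048 hidden size are illustrative assumptions, not part of this commit, and the actual server-side decoding lives inside vLLM.

import base64
import io

import torch


def decode_prompt_embeds(encoded: str) -> torch.Tensor:
    # Hypothetical helper: reverse of create_dummy_embeds, i.e.
    # base64 string -> bytes -> torch tensor.
    return torch.load(io.BytesIO(base64.b64decode(encoded)))


dummy = torch.randn(5, 2048)  # 2048 is a placeholder hidden size
buffer = io.BytesIO()
torch.save(dummy, buffer)
encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")

decoded = decode_prompt_embeds(encoded)
assert decoded.shape == dummy.shape
assert torch.equal(decoded, dummy)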
