
Commit be4c84b

noooop, gemini-code-assist[bot], and DarkLight1337 authored and committed
[Frontend][3/N] Improve all pooling task | Support binary embedding response (vllm-project#27066)
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
1 parent 85ea95c commit be4c84b

File tree

12 files changed: +693 -232 lines changed


examples/online_serving/pooling/README.md

Lines changed: 8 additions & 2 deletions
@@ -6,10 +6,16 @@
 python examples/online_serving/pooling/cohere_rerank_client.py
 ```
 
-## Embedding embed_dtype usage
+## Embedding requests base64 encoding_format usage
 
 ```bash
-python examples/online_serving/pooling/embedding_embed_dtype_client.py
+python examples/online_serving/pooling/embedding_requests_base64_client.py
+```
+
+## Embedding requests bytes encoding_format usage
+
+```bash
+python examples/online_serving/pooling/embedding_requests_bytes_client.py
 ```
 
 ## Jinaai rerank usage

examples/online_serving/pooling/embedding_embed_dtype_client.py renamed to examples/online_serving/pooling/embedding_requests_base64_client.py

Lines changed: 24 additions & 19 deletions
@@ -12,7 +12,11 @@
 import requests
 import torch
 
-from vllm.entrypoints.openai.protocol import EMBED_DTYPE_TO_TORCH_DTYPE
+from vllm.utils.serial_utils import (
+    EMBED_DTYPE_TO_TORCH_DTYPE,
+    ENDIANNESS,
+    binary2tensor,
+)
 
 
 def post_http_request(prompt: dict, api_url: str) -> requests.Response:
@@ -34,24 +38,25 @@ def main(args):
     api_url = f"http://{args.host}:{args.port}/v1/embeddings"
     model_name = args.model
 
-    for embed_dtype, torch_dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
-        prompt = {
-            "model": model_name,
-            "input": "vLLM is great!",
-            "encoding_format": "base64",
-            "embed_dtype": embed_dtype,
-        }
-        response = post_http_request(prompt=prompt, api_url=api_url)
-
-        embedding = []
-        for data in response.json()["data"]:
-            embedding.append(
-                torch.frombuffer(
-                    base64.b64decode(data["embedding"]), dtype=torch_dtype
-                ).to(torch.float32)
-            )
-        embedding = torch.cat(embedding)
-        print(embed_dtype, embedding.shape)
+    # The OpenAI client does not support the embed_dtype and endianness parameters.
+    for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
+        for endianness in ENDIANNESS:
+            prompt = {
+                "model": model_name,
+                "input": "vLLM is great!",
+                "encoding_format": "base64",
+                "embed_dtype": embed_dtype,
+                "endianness": endianness,
+            }
+            response = post_http_request(prompt=prompt, api_url=api_url)
+
+            embedding = []
+            for data in response.json()["data"]:
+                binary = base64.b64decode(data["embedding"])
+                tensor = binary2tensor(binary, (-1,), embed_dtype, endianness)
+                embedding.append(tensor.to(torch.float32))
+            embedding = torch.cat(embedding)
+            print(embed_dtype, endianness, embedding.shape)
 
 
 if __name__ == "__main__":
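
The helper `binary2tensor` used above is imported from `vllm.utils.serial_utils`, whose body this diff does not show. For orientation, here is a minimal sketch of what such a helper could look like, inferred only from the call sites above; the dtype table, byte-order table, and the `"little"`/`"big"` endianness labels are illustrative guesses, not vLLM's actual values:

```python
# Hypothetical sketch only -- not vLLM's binary2tensor implementation.
import numpy as np
import torch

_NP_DTYPE = {"float32": "f4", "float16": "f2"}  # assumed dtype subset
_BYTE_ORDER = {"little": "<", "big": ">"}  # assumed endianness labels


def binary2tensor_sketch(
    binary: bytes, shape: tuple, embed_dtype: str, endianness: str
) -> torch.Tensor:
    # Interpret the raw bytes with the declared dtype and byte order ...
    dtype = np.dtype(_BYTE_ORDER[endianness] + _NP_DTYPE[embed_dtype])
    arr = np.frombuffer(binary, dtype=dtype)
    # ... then convert to native byte order so torch can consume the buffer.
    arr = arr.astype(dtype.newbyteorder("="))
    return torch.from_numpy(arr).reshape(shape)
```

Declaring the wire byte order explicitly, instead of assuming the host's, is what lets the same payload round-trip between little- and big-endian machines.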
examples/online_serving/pooling/embedding_requests_bytes_client.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Example Python client for embedding API using vLLM API server
+NOTE:
+    start a supported embeddings model server with `vllm serve`, e.g.
+    vllm serve intfloat/e5-small
+"""
+
+import argparse
+import json
+
+import requests
+import torch
+
+from vllm.utils.serial_utils import (
+    EMBED_DTYPE_TO_TORCH_DTYPE,
+    ENDIANNESS,
+    MetadataItem,
+    decode_pooling_output,
+)
+
+
+def post_http_request(prompt: dict, api_url: str) -> requests.Response:
+    headers = {"User-Agent": "Test Client"}
+    response = requests.post(api_url, headers=headers, json=prompt)
+    return response
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--model", type=str, default="intfloat/e5-small")
+
+    return parser.parse_args()
+
+
+def main(args):
+    api_url = f"http://{args.host}:{args.port}/v1/embeddings"
+    model_name = args.model
+
+    # The OpenAI client does not support the bytes encoding_format.
+    # The OpenAI client does not support the embed_dtype and endianness parameters.
+    for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
+        for endianness in ENDIANNESS:
+            prompt = {
+                "model": model_name,
+                "input": "vLLM is great!",
+                "encoding_format": "bytes",
+                "embed_dtype": embed_dtype,
+                "endianness": endianness,
+            }
+            response = post_http_request(prompt=prompt, api_url=api_url)
+            metadata = json.loads(response.headers["metadata"])
+            body = response.content
+            items = [MetadataItem(**x) for x in metadata["data"]]
+
+            embedding = decode_pooling_output(items=items, body=body)
+            embedding = [x.to(torch.float32) for x in embedding]
+            embedding = torch.cat(embedding)
+            print(embed_dtype, endianness, embedding.shape)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
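
Unlike the base64 path, the bytes path returns one packed binary body plus a `metadata` response header describing each embedding's slice. The diff does not include the bodies of `MetadataItem` or `decode_pooling_output`; the following is a minimal sketch under assumed field names (`start`, `end`, `dtype`, `shape` are guesses for illustration, not vLLM's schema; byte-order handling as in the base64 sketch above is omitted for brevity):

```python
# Hypothetical sketch only -- field names are assumptions, not vLLM's schema.
from dataclasses import dataclass

import numpy as np
import torch


@dataclass
class MetadataItemSketch:
    start: int   # byte offset of this tensor within the response body
    end: int     # exclusive end offset
    dtype: str   # e.g. "float32"
    shape: list  # shape to restore, e.g. [384]


def decode_pooling_output_sketch(
    items: list[MetadataItemSketch], body: bytes
) -> list[torch.Tensor]:
    tensors = []
    for item in items:
        # Slice this tensor's bytes out of the shared body and rebuild it.
        arr = np.frombuffer(body[item.start : item.end], dtype=item.dtype).copy()
        tensors.append(torch.from_numpy(arr).reshape(item.shape))
    return tensors
```

Packing all tensors into a single body avoids base64's ~33% size overhead and one decode pass per embedding.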

tests/entrypoints/pooling/openai/test_embedding.py

Lines changed: 82 additions & 34 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import base64
+import json
 
 import numpy as np
 import openai
@@ -15,11 +16,17 @@
 from tests.models.utils import check_embeddings_close
 from tests.utils import RemoteOpenAIServer
 from vllm.entrypoints.openai.protocol import (
-    EMBED_DTYPE_TO_TORCH_DTYPE,
     EmbeddingResponse,
     PoolingResponse,
 )
 from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.utils.serial_utils import (
+    EMBED_DTYPE_TO_TORCH_DTYPE,
+    ENDIANNESS,
+    MetadataItem,
+    binary2tensor,
+    decode_pooling_output,
+)
 
 MODEL_NAME = "intfloat/multilingual-e5-small"
 DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}"""  # noqa: E501
@@ -250,8 +257,8 @@ async def test_batch_base64_embedding(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_base64_embed_dtype(
-    hf_model, server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
+async def test_base64_embed_dtype_and_endianness(
+    server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
 ):
     input_texts = [
         "The best thing about vLLM is that it supports many different models",
@@ -262,59 +269,100 @@ async def test_base64_embed_dtype(
     )
     float_data = [d.embedding for d in responses_float.data]
 
-    for embed_dtype, torch_dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
-        responses_base64 = requests.post(
-            server.url_for("/v1/embeddings"),
-            json={
-                "model": model_name,
-                "input": input_texts,
-                "encoding_format": "base64",
-                "embed_dtype": embed_dtype,
-            },
-        )
-
-        base64_data = []
-        for data in responses_base64.json()["data"]:
-            base64_data.append(
-                torch.frombuffer(base64.b64decode(data["embedding"]), dtype=torch_dtype)
-                .to(torch.float32)
-                .tolist()
+    for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
+        for endianness in ENDIANNESS:
+            responses_base64 = requests.post(
+                server.url_for("/v1/embeddings"),
+                json={
+                    "model": model_name,
+                    "input": input_texts,
+                    "encoding_format": "base64",
+                    "embed_dtype": embed_dtype,
+                    "endianness": endianness,
+                },
             )
 
-        check_embeddings_close(
-            embeddings_0_lst=float_data,
-            embeddings_1_lst=base64_data,
-            name_0="float_data",
-            name_1="base64_data",
-            tol=1e-2,
-        )
+            base64_data = []
+            for data in responses_base64.json()["data"]:
+                binary = base64.b64decode(data["embedding"])
+                tensor = binary2tensor(binary, (-1,), embed_dtype, endianness)
+                base64_data.append(tensor.to(torch.float32).tolist())
+
+            check_embeddings_close(
+                embeddings_0_lst=float_data,
+                embeddings_1_lst=base64_data,
+                name_0="float_data",
+                name_1="base64_data",
+                tol=1e-2,
+            )
 
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_base64_embed_dtype_not_supported(
-    hf_model, server: RemoteOpenAIServer, model_name: str
+async def test_bytes_embed_dtype_and_endianness(
+    server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
 ):
     input_texts = [
         "The best thing about vLLM is that it supports many different models",
     ]
 
-    bad_embed_dtype = "bad_embed_dtype"
+    responses_float = await client.embeddings.create(
+        input=input_texts, model=model_name, encoding_format="float"
+    )
+    float_data = [d.embedding for d in responses_float.data]
+
+    for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()):
+        for endianness in ENDIANNESS:
+            responses_bytes = requests.post(
+                server.url_for("/v1/embeddings"),
+                json={
+                    "model": model_name,
+                    "input": input_texts,
+                    "encoding_format": "bytes",
+                    "embed_dtype": embed_dtype,
+                    "endianness": endianness,
+                },
+            )
+
+            metadata = json.loads(responses_bytes.headers["metadata"])
+            body = responses_bytes.content
+            items = [MetadataItem(**x) for x in metadata["data"]]
+
+            bytes_data = decode_pooling_output(items=items, body=body)
+            bytes_data = [x.to(torch.float32).tolist() for x in bytes_data]
+
+            check_embeddings_close(
+                embeddings_0_lst=float_data,
+                embeddings_1_lst=bytes_data,
+                name_0="float_data",
+                name_1="bytes_data",
+                tol=1e-2,
+            )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("param_name", ["encoding_format", "embed_dtype", "endianness"])
+async def test_params_not_supported(
+    server: RemoteOpenAIServer, model_name: str, param_name: str
+):
+    input_texts = [
+        "The best thing about vLLM is that it supports many different models",
+    ]
 
     responses_base64 = requests.post(
         server.url_for("/v1/embeddings"),
         json={
             "model": model_name,
             "input": input_texts,
             "encoding_format": "base64",
-            "embed_dtype": bad_embed_dtype,
+            param_name: f"bad_{param_name}",
         },
    )
 
     assert responses_base64.status_code == 400
-    assert responses_base64.json()["error"]["message"].startswith(
-        f"embed_dtype={bad_embed_dtype!r} is not supported."
-    )
+    assert "literal_error" in responses_base64.json()["error"]["message"]
+    assert f"bad_{param_name}" in responses_base64.json()["error"]["message"]
 
 
 @pytest.mark.asyncio
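
The rewritten negative test no longer matches a hand-written error string; it checks for Pydantic's `literal_error` marker instead, so one parametrized test covers all three fields. A standalone illustration of where that marker comes from, assuming the server validates these request fields as `typing.Literal` with Pydantic v2 (the Literal values below are placeholders, not the server's exact schema):

```python
# Standalone illustration; the Literal values are placeholders.
from typing import Literal

from pydantic import BaseModel, ValidationError


class EmbeddingParamsSketch(BaseModel):
    encoding_format: Literal["float", "base64", "bytes"] = "base64"
    embed_dtype: Literal["float32", "float16", "bfloat16"] = "float32"
    endianness: Literal["little", "big"] = "little"


try:
    EmbeddingParamsSketch(embed_dtype="bad_embed_dtype")
except ValidationError as e:
    # Pydantic v2 tags Literal mismatches as "literal_error", which is the
    # substring the updated test asserts on.
    print(e.errors()[0]["type"])  # -> "literal_error"
```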
