Commit dd06fe1

/pooling endpoint supports all pooling tasks

Signed-off-by: wang.yuqi <noooop@126.com>
1 parent fb5fdfa commit dd06fe1
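
What this buys a client: the "task" field on the /pooling request body now selects any pooling task the model supports, instead of the endpoint being hard-wired to token-level pooling. A minimal sketch of the new call shape, assuming a local vLLM server already running the classifier model from the tests below (endpoint path, fields, and expected output shape are taken from this diff; the base URL is hypothetical):

import requests

# Hypothetical local server; model and request shape mirror the tests
# added in this commit.
BASE_URL = "http://localhost:8000"

response = requests.post(
    f"{BASE_URL}/pooling",
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": "This product was excellent and exceeded my expectations",
        "encoding_format": "float",
        "task": "classify",  # new in this commit: explicit task selection
    },
)
# For this two-label model the tests below expect two per-class values.
print(response.json()["data"][0]["data"])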

File tree: 6 files changed, +223 −38 lines changed

tests/entrypoints/pooling/openai/test_classification.py

Lines changed: 62 additions & 14 deletions
@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 
 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import ClassificationResponse
+from vllm.entrypoints.openai.protocol import ClassificationResponse, PoolingResponse
 
 MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
 DTYPE = "float32"  # Use float32 to avoid NaN issue
@@ -191,18 +191,7 @@ async def get_outputs(activation):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_pooling(server: RemoteOpenAIServer, model_name: str):
-    # pooling api uses ALL pooling, which does not support chunked prefill.
-    response = requests.post(
-        server.url_for("pooling"),
-        json={"model": model_name, "input": "test", "encoding_format": "float"},
-    )
-    assert response.json()["error"]["type"] == "BadRequestError"
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_score(server: RemoteOpenAIServer, model_name: str):
+async def test_score(server: RemoteOpenAIServer, model_name: str):
     # score api is only enabled for num_labels == 1.
     response = requests.post(
         server.url_for("score"),
@@ -217,7 +206,7 @@ def test_score(server: RemoteOpenAIServer, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_rerank(server: RemoteOpenAIServer, model_name: str):
+async def test_rerank(server: RemoteOpenAIServer, model_name: str):
     # rerank api is only enabled for num_labels == 1.
     response = requests.post(
         server.url_for("rerank"),
@@ -228,3 +217,62 @@ def test_rerank(server: RemoteOpenAIServer, model_name: str):
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
+    input_text = "This product was excellent and exceeded my expectations"
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": input_text,
+            "encoding_format": "float",
+            "task": "classify",
+        },
+    )
+    poolings = PoolingResponse.model_validate(response.json())
+    assert len(poolings.data) == 1
+    assert len(poolings.data[0].data) == 2
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
+    # token_classify uses ALL pooling, which does not support chunked prefill.
+    task = "token_classify"
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": "test",
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
+    assert response.json()["error"]["message"].startswith(
+        f"Task {task} is not supported"
+    )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
+async def test_pooling_not_supported(
+    server: RemoteOpenAIServer, model_name: str, task: str
+):
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": "test",
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
+    assert response.json()["error"]["message"].startswith(
+        f"Task {task} is not supported"
+    )
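
Note: the error prefix asserted here, f"Task {task} is not supported", is the new message introduced in serving_pooling.py at the end of this commit, so these tests pin the endpoint's public error contract.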

tests/entrypoints/pooling/openai/test_embedding.py

Lines changed: 51 additions & 2 deletions
@@ -562,16 +562,65 @@ async def get_outputs(normalize):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_pooling(server: RemoteOpenAIServer, model_name: str):
+async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str):
+    task = "embed"
     input_text = ["The chef prepared a delicious meal."]
 
     response = requests.post(
         server.url_for("pooling"),
-        json={"model": model_name, "input": input_text, "encoding_format": "float"},
+        json={
+            "model": model_name,
+            "input": input_text,
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+
+    poolings = PoolingResponse.model_validate(response.json())
+
+    assert len(poolings.data) == 1
+    assert len(poolings.data[0].data) == 384
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
+    task = "token_embed"
+    input_text = ["The chef prepared a delicious meal."]
+
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": input_text,
+            "encoding_format": "float",
+            "task": task,
+        },
     )
 
     poolings = PoolingResponse.model_validate(response.json())
 
     assert len(poolings.data) == 1
     assert len(poolings.data[0].data) == 11
     assert len(poolings.data[0].data[0]) == 384
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
+async def test_pooling_not_supported(
+    server: RemoteOpenAIServer, model_name: str, task: str
+):
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": "test",
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
+    assert response.json()["error"]["message"].startswith(
+        f"Task {task} is not supported"
+    )
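
Taken together, the two happy-path tests encode the shape contract: "embed" yields one 384-dimensional vector per input, while "token_embed" yields one 384-dimensional vector per token (11 tokens for this prompt).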

tests/entrypoints/pooling/openai/test_rerank.py

Lines changed: 40 additions & 1 deletion
@@ -163,7 +163,25 @@ async def get_outputs(activation):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_pooling(server: RemoteOpenAIServer, model_name: str):
+async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
+    input_text = "This product was excellent and exceeded my expectations"
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": input_text,
+            "encoding_format": "float",
+            "task": "classify",
+        },
+    )
+    poolings = PoolingResponse.model_validate(response.json())
+    assert len(poolings.data) == 1
+    assert len(poolings.data[0].data) == 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
     input_text = ["The chef prepared a delicious meal."]
 
     response = requests.post(
@@ -176,3 +194,24 @@ async def test_pooling(server: RemoteOpenAIServer, model_name: str):
     assert len(poolings.data) == 1
     assert len(poolings.data[0].data) == 11
     assert len(poolings.data[0].data[0]) == 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
+async def test_pooling_not_supported(
+    server: RemoteOpenAIServer, model_name: str, task: str
+):
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": "test",
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
+    assert response.json()["error"]["message"].startswith(
+        f"Task {task} is not supported"
+    )

vllm/entrypoints/openai/api_server.py

Lines changed: 1 addition & 6 deletions
@@ -1749,12 +1749,7 @@ async def init_app_state(
                 log_error_stack=args.log_error_stack,
             )
         )
-        if (
-            any(
-                task in supported_tasks
-                for task in ["token_embed", "token_classify", "plugin"]
-            )
-        )
+        if supported_tasks
         else None
     )
     state.openai_serving_embedding = (
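
Previously the /pooling handler was only constructed when a token-level or plugin task was available; since the endpoint now dispatches every pooling task, any non-empty supported_tasks is reason enough to build it.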

vllm/entrypoints/openai/protocol.py

Lines changed: 53 additions & 8 deletions
@@ -49,6 +49,7 @@
 )
 from openai_harmony import Message as OpenAIHarmonyMessage
 
+from vllm.tasks import PoolingTask
 from vllm.utils.serial_utils import (
     EmbedDType,
     EncodingFormat,
@@ -1669,8 +1670,42 @@ def to_pooling_params(self):
 
 EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
 
-PoolingCompletionRequest = EmbeddingCompletionRequest
-PoolingChatRequest = EmbeddingChatRequest
+
+class PoolingCompletionRequest(EmbeddingCompletionRequest):
+    task: PoolingTask | None = None
+    activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "If it is a classify or token_classify task, the default is True; "
+        "for other tasks, this value should be None.",
+    )
+
+    def to_pooling_params(self):
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            dimensions=self.dimensions,
+            normalize=self.normalize,
+            activation=self.activation,
+        )
+
+
+class PoolingChatRequest(EmbeddingChatRequest):
+    task: PoolingTask | None = None
+    activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "If it is a classify or token_classify task, the default is True; "
+        "for other tasks, this value should be None.",
+    )
+
+    def to_pooling_params(self):
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            dimensions=self.dimensions,
+            normalize=self.normalize,
+            activation=self.activation,
+        )
+
 
 T = TypeVar("T")
@@ -1686,6 +1721,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
     """
     data: T
 
+    task: PoolingTask = "plugin"
     encoding_format: EncodingFormat = "float"
     embed_dtype: EmbedDType = Field(
         default="float32",
@@ -1749,8 +1785,11 @@ class ScoreRequest(OpenAIBaseModel):
         ),
     )
 
-    activation: bool | None = None
-
+    activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "Default is True.",
+    )
     # --8<-- [end:score-extra-params]
 
     def to_pooling_params(self):
@@ -1783,8 +1822,11 @@ class RerankRequest(OpenAIBaseModel):
         ),
     )
 
-    activation: bool | None = None
-
+    activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "Default is True.",
+    )
     # --8<-- [end:rerank-extra-params]
 
     def to_pooling_params(self):
@@ -1958,8 +2000,11 @@ class ClassificationRequest(OpenAIBaseModel):
         ),
     )
 
-    activation: bool | None = None
-
+    activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "Default is True.",
+    )
     # --8<-- [end:classification-extra-params]
 
     def to_pooling_params(self):
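
The net effect of the protocol changes: PoolingCompletionRequest and PoolingChatRequest are no longer plain aliases of the embedding requests but subclasses that add task and activation. A hedged round-trip sketch (field names come from this diff; model/input and the truncate/dimensions/normalize fields are inherited from EmbeddingCompletionRequest, so treat the exact constructor shape as an assumption):

from vllm.entrypoints.openai.protocol import PoolingCompletionRequest

# Assumption: EmbeddingCompletionRequest supplies `model` and `input`;
# `task` and `activation` are the fields added in this diff.
req = PoolingCompletionRequest(
    model="jason9693/Qwen2.5-1.5B-apeach",
    input="test",
    task="classify",
    activation=True,  # default None; classify/token_classify treat None as True
)
params = req.to_pooling_params()
assert params.activation is True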

vllm/entrypoints/openai/serving_pooling.py

Lines changed: 16 additions & 7 deletions
@@ -170,15 +170,24 @@ async def create_pooling(
         pooling_params = request.to_pooling_params()
 
         pooling_task: PoolingTask
-        if "token_embed" in self.supported_tasks:
-            pooling_task = "token_embed"
-        elif "token_classify" in self.supported_tasks:
-            pooling_task = "token_classify"
-        elif "plugin" in self.supported_tasks:
-            pooling_task = "plugin"
+        if request.task is None:
+            if "token_embed" in self.supported_tasks:
+                pooling_task = "token_embed"
+            elif "token_classify" in self.supported_tasks:
+                pooling_task = "token_classify"
+            elif "plugin" in self.supported_tasks:
+                pooling_task = "plugin"
+            else:
+                return self.create_error_response(
+                    f"pooling_task must be one of {self.supported_tasks}."
+                )
         else:
+            pooling_task = request.task
+
+        if pooling_task not in self.supported_tasks:
             return self.create_error_response(
-                f"pooling_task must be one of {self.supported_tasks}."
+                f"Task {pooling_task} is not supported, it"
+                f" must be one of {self.supported_tasks}."
             )
 
         try:
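
The dispatch above reads naturally as a small pure function. The following is a dependency-free restatement for clarity, not code from the repository; the fallback order and error strings mirror the diff:

def resolve_pooling_task(requested: str | None, supported: set[str]) -> str:
    # No explicit task on the request: fall back through the
    # pre-existing token-level/plugin defaults.
    if requested is None:
        for fallback in ("token_embed", "token_classify", "plugin"):
            if fallback in supported:
                requested = fallback
                break
        else:
            raise ValueError(f"pooling_task must be one of {supported}.")
    # An explicit (or resolved) task must be supported by the model.
    if requested not in supported:
        raise ValueError(
            f"Task {requested} is not supported, it"
            f" must be one of {supported}."
        )
    return requested

# An embedding model still defaults to token_embed, but "embed" now works too:
assert resolve_pooling_task(None, {"embed", "token_embed"}) == "token_embed"
assert resolve_pooling_task("embed", {"embed", "token_embed"}) == "embed"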
