Commit 42fce02

Score API
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent a71e476 commit 42fce02

4 files changed: +40 -8 lines changed

tests/conftest.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -454,11 +454,10 @@ def classify(self, prompts: list[str]) -> list[str]:
         # output is final logits
         all_inputs = self.get_inputs(prompts)
         outputs = []
+        problem_type = getattr(self.config, "problem_type", "")
+
         for inputs in all_inputs:
             output = self.model(**self.wrap_device(inputs))
-
-            problem_type = getattr(self.config, "problem_type", "")
-
             if problem_type == "regression":
                 logits = output.logits[0].tolist()
             elif problem_type == "multi_label_classification":
```
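
The hoisted lookup is safe because `problem_type` comes from the model config, not from any individual batch, so reading it once before the loop is equivalent. For context, a minimal standalone sketch of the branching this helper performs; the final softmax branch is an assumption about the default single-label path, following the usual HF post-processing:

```python
import torch

def postprocess_logits(logits: torch.Tensor, problem_type: str) -> list[float]:
    # Mirrors the classify() branching: regression returns raw logits,
    # multi-label classification applies an element-wise sigmoid, and the
    # (assumed) single-label default applies a softmax over the label dim.
    if problem_type == "regression":
        return logits.tolist()
    if problem_type == "multi_label_classification":
        return torch.sigmoid(logits).tolist()
    return torch.softmax(logits, dim=-1).tolist()
```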

tests/entrypoints/openai/test_classification.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -226,3 +226,33 @@ def test_pooling(server: RemoteOpenAIServer, model_name: str):
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_score(server: RemoteOpenAIServer, model_name: str):
+    # score api is only enabled for num_labels == 1.
+    response = requests.post(
+        server.url_for("score"),
+        json={
+            "model": model_name,
+            "text_1": "ping",
+            "text_2": "pong",
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_rerank(server: RemoteOpenAIServer, model_name: str):
+    # rerank api is only enabled for num_labels == 1.
+    response = requests.post(
+        server.url_for("rerank"),
+        json={
+            "model": model_name,
+            "query": "ping",
+            "documents": ["pong"],
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
```
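
Both new tests assert the error path because the classification fixture model does not have num_labels == 1. For contrast, a hypothetical happy-path call against a server launched with a single-label cross-encoder; the URL, model name, and `data[0].score` response shape here are illustrative assumptions, not part of the committed tests:

```python
import requests

# Hypothetical: served with e.g. BAAI/bge-reranker-v2-m3 (num_labels == 1),
# the same request succeeds and returns one score per (text_1, text_2) pair.
response = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "BAAI/bge-reranker-v2-m3",
        "text_1": "ping",
        "text_2": "pong",
    },
)
response.raise_for_status()
print(response.json()["data"][0]["score"])
```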

vllm/entrypoints/openai/api_server.py

Lines changed: 1 addition & 5 deletions
```diff
@@ -1797,16 +1797,12 @@ async def init_app_state(
         state.openai_serving_models,
         request_logger=request_logger,
     ) if "classify" in supported_tasks else None
-
-    enable_serving_reranking = ("classify" in supported_tasks and getattr(
-        model_config.hf_config, "num_labels", 0) == 1)
     state.openai_serving_scores = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if ("embed" in supported_tasks or enable_serving_reranking) else None
-
+    ) if ("embed" in supported_tasks or "score" in supported_tasks) else None
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
```
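
The API layer previously re-derived rerank eligibility from `num_labels` on its own; after this change it trusts the task list the engine reports, since the model runner now strips "score" itself (next file). A sketch of just the gating expressions, with the surrounding wiring elided:

```python
# Before: duplicated model introspection at the API layer.
enable_serving_reranking = ("classify" in supported_tasks and getattr(
    model_config.hf_config, "num_labels", 0) == 1)
serve_scores = "embed" in supported_tasks or enable_serving_reranking

# After: a plain membership test, because "score" only appears in
# supported_tasks when the runner has verified num_labels == 1.
serve_scores = "embed" in supported_tasks or "score" in supported_tasks
```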

vllm/v1/worker/gpu_model_runner.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -1302,6 +1302,13 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]:
                 "Please turn off chunked prefill by "
                 "`--no-enable-chunked-prefill` before using it.")
 
+        if "score" in supported_tasks:
+            num_labels = getattr(
+                self.model_config.hf_config, "num_labels", 0)
+            if num_labels != 1:
+                supported_tasks.remove("score")
+                logger.info_once("Score API is only enabled for num_labels == 1.")
+
         return supported_tasks
 
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
```
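
A self-contained sketch of the gate this hunk adds, assuming an HF-style config object; vLLM's `logger.info_once` is replaced by a plain print so the snippet runs standalone:

```python
from types import SimpleNamespace

def filter_supported_tasks(supported_tasks: list[str], hf_config) -> list[str]:
    # Scoring needs a single-logit head (one relevance score per pair), so
    # models with num_labels != 1 keep their other pooling tasks but lose
    # "score", and with it the /score and /rerank endpoints.
    if "score" in supported_tasks:
        num_labels = getattr(hf_config, "num_labels", 0)
        if num_labels != 1:
            supported_tasks.remove("score")
            print("Score API is only enabled for num_labels == 1.")
    return supported_tasks

# A two-label classifier loses "score" but keeps "classify".
tasks = filter_supported_tasks(
    ["encode", "embed", "classify", "score"],
    SimpleNamespace(num_labels=2),
)
assert tasks == ["encode", "embed", "classify"]
```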
