Merged
1 change: 1 addition & 0 deletions docs/source/models/pooling_models.md
@@ -140,6 +140,7 @@ Our [OpenAI-Compatible Server](#openai-compatible-server) provides endpoints tha

- [Pooling API](#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models.
- [Embeddings API](#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](#multimodal-inputs) for embedding models.
- [Classification API](#classification-api) is similar to `LLM.classify` and is applicable to sequence classification models.
- [Score API](#score-api) is similar to `LLM.score` for cross-encoder models.

## Matryoshka Embeddings
126 changes: 126 additions & 0 deletions docs/source/serving/openai_compatible_server.md
@@ -61,6 +61,8 @@ In addition, we have the following custom APIs:
- Applicable to any model with a tokenizer.
- [Pooling API](#pooling-api) (`/pooling`)
- Applicable to all [pooling models](../models/pooling_models.md).
- [Classification API](#classification-api) (`/classify`)
- Only applicable to [classification models](../models/pooling_models.md) (`--task classify`).
- [Score API](#score-api) (`/score`)
- Applicable to embedding models and [cross-encoder models](../models/pooling_models.md) (`--task score`).
- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
@@ -443,6 +445,130 @@ The input format is the same as [Embeddings API](#embeddings-api), but the outpu

Code example: <gh-file:examples/online_serving/openai_pooling_client.py>

(classification-api)=

### Classification API

Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach).

We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
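
Conceptually, the wrapping behaves like the sketch below. This is illustrative only: the shapes are made up, and the plain `torch.nn.Linear` is a stand-in for the `RowParallelLinear` head, not vLLM's actual implementation.

```python
import torch

hidden_size, num_classes = 1536, 2

# Pretend these are the decoder's final hidden states for one prompt.
hidden_states = torch.randn(1, 10, hidden_size)   # (batch, seq_len, hidden)

pooled = hidden_states[:, -1, :]                  # pool on the last token

head = torch.nn.Linear(hidden_size, num_classes)  # stand-in for RowParallelLinear
probs = torch.softmax(head(pooled), dim=-1)       # per-class probabilities

print(probs.shape)  # torch.Size([1, 2]); each row sums to 1
```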

Code example: <gh-file:examples/online_serving/openai_classification_client.py>

#### Example Requests

You can classify multiple texts by passing an array of strings:

Request:

```bash
curl -v "http://127.0.0.1:8000/classify" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "jason9693/Qwen2.5-1.5B-apeach",
    "input": [
      "Loved the new café—coffee was great.",
      "This update broke everything. Frustrating."
    ]
  }'
```

Response:

```json
{
  "id": "classify-7c87cac407b749a6935d8c7ce2a8fba2",
  "object": "list",
  "created": 1745383065,
  "model": "jason9693/Qwen2.5-1.5B-apeach",
  "data": [
    {
      "index": 0,
      "label": "Default",
      "probs": [
        0.565970778465271,
        0.4340292513370514
      ],
      "num_classes": 2
    },
    {
      "index": 1,
      "label": "Spoiled",
      "probs": [
        0.26448777318000793,
        0.7355121970176697
      ],
      "num_classes": 2
    }
  ],
  "usage": {
    "prompt_tokens": 20,
    "total_tokens": 20,
    "completion_tokens": 0,
    "prompt_tokens_details": null
  }
}
```

You can also pass a string directly to the `input` field:

Request:

```bash
curl -v "http://127.0.0.1:8000/classify" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "jason9693/Qwen2.5-1.5B-apeach",
    "input": "Loved the new café—coffee was great."
  }'
```

Response:

```json
{
  "id": "classify-9bf17f2847b046c7b2d5495f4b4f9682",
  "object": "list",
  "created": 1745383213,
  "model": "jason9693/Qwen2.5-1.5B-apeach",
  "data": [
    {
      "index": 0,
      "label": "Default",
      "probs": [
        0.565970778465271,
        0.4340292513370514
      ],
      "num_classes": 2
    }
  ],
  "usage": {
    "prompt_tokens": 10,
    "total_tokens": 10,
    "completion_tokens": 0,
    "prompt_tokens_details": null
  }
}
```
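
In both responses, `label` names the highest-probability class; the class names typically come from the model's `id2label` config. A minimal sketch of recovering the predicted index from the single-input response above:

```python
data = {
    "label": "Default",
    "probs": [0.565970778465271, 0.4340292513370514],
    "num_classes": 2,
}

# The argmax over probs matches the index of the reported label.
predicted_index = max(range(data["num_classes"]), key=lambda i: data["probs"][i])
print(predicted_index, data["label"])  # 0 Default
```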

#### Extra parameters

The following [pooling parameters](#pooling-params) are supported.

:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-classification-pooling-params
:end-before: end-classification-pooling-params
:::

The following extra parameters are supported:

:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-classification-extra-params
:end-before: end-classification-extra-params
:::
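
For example, `truncate_prompt_tokens` (also exercised in this PR's tests) caps how many prompt tokens are kept. A minimal sketch with `requests`, assuming a server on `localhost:8000`:

```python
import requests

response = requests.post(
    "http://localhost:8000/classify",
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": "hello " * 600,
        "truncate_prompt_tokens": 5,  # keep at most 5 prompt tokens
    },
)
response.raise_for_status()
print(response.json()["usage"])  # expect prompt_tokens == 5
```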

(score-api)=

### Score API
49 changes: 49 additions & 0 deletions examples/online_serving/openai_classification_client.py
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import pprint

import requests


def post_http_request(payload: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=payload)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model",
                        type=str,
                        default="jason9693/Qwen2.5-1.5B-apeach")
    return parser.parse_args()


def main(args):
    host = args.host
    port = args.port
    model_name = args.model

    api_url = f"http://{host}:{port}/classify"
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    payload = {
        "model": model_name,
        "input": prompts,
    }

    classify_response = post_http_request(payload=payload, api_url=api_url)
    pprint.pprint(classify_response.json())


if __name__ == "__main__":
    args = parse_args()
    main(args)
156 changes: 156 additions & 0 deletions tests/entrypoints/openai/test_classification.py
@@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import requests

from vllm.entrypoints.openai.protocol import ClassificationResponse

from ...utils import RemoteOpenAIServer

MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
DTYPE = "float32" # Use float32 to avoid NaN issue


@pytest.fixture(scope="module")
def server():
    args = [
        "--enforce-eager",
        "--max-model-len",
        "512",
        "--dtype",
        DTYPE,
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_single_input_classification(server: RemoteOpenAIServer,
                                     model_name: str):
    input_text = "This product was excellent and exceeded my expectations"

    classification_response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": input_text
        },
    )

    classification_response.raise_for_status()
    output = ClassificationResponse.model_validate(
        classification_response.json())

    assert output.object == "list"
    assert output.model == MODEL_NAME
    assert len(output.data) == 1
    assert hasattr(output.data[0], "label")
    assert hasattr(output.data[0], "probs")


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_multiple_inputs_classification(server: RemoteOpenAIServer,
                                        model_name: str):
    input_texts = [
        "The product arrived on time and works perfectly",
        "I'm very satisfied with my purchase, would buy again",
        "The customer service was helpful and resolved my issue quickly",
        "This product broke after one week, terrible quality",
        "I'm very disappointed with this purchase, complete waste of money",
        "The customer service was rude and unhelpful",
    ]

    classification_response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": input_texts
        },
    )
    # Fail fast on HTTP errors, matching the other tests.
    classification_response.raise_for_status()
    output = ClassificationResponse.model_validate(
        classification_response.json())

    assert len(output.data) == len(input_texts)
    for i, item in enumerate(output.data):
        assert item.index == i
        assert hasattr(item, "label")
        assert hasattr(item, "probs")
        assert len(item.probs) == item.num_classes
        assert item.label in ["Default", "Spoiled"]


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str):
    long_text = "hello " * 600

    classification_response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": long_text,
            "truncate_prompt_tokens": 5
        },
    )

    classification_response.raise_for_status()
    output = ClassificationResponse.model_validate(
        classification_response.json())

    assert len(output.data) == 1
    assert output.data[0].index == 0
    assert hasattr(output.data[0], "probs")
    assert output.usage.prompt_tokens == 5
    assert output.usage.total_tokens == 5


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer,
                                              model_name: str):
    classification_response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": "test",
            "truncate_prompt_tokens": 513
        },
    )

    error = classification_response.json()
    assert classification_response.status_code == 400
    assert error["object"] == "error"
    assert "truncate_prompt_tokens" in error["message"]


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
    classification_response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": ""
        },
    )

    error = classification_response.json()
    assert classification_response.status_code == 400
    assert error["object"] == "error"


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_batch_classification_empty_list(server: RemoteOpenAIServer,
                                         model_name: str):
    classification_response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": []
        },
    )
    classification_response.raise_for_status()
    output = ClassificationResponse.model_validate(
        classification_response.json())

    assert output.object == "list"
    assert isinstance(output.data, list)
    assert len(output.data) == 0