Skip to content

Commit c6d4cb7

Browse files
committed
Polishing
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent 179ff0d commit c6d4cb7

File tree

7 files changed

+100
-24
lines changed

7 files changed

+100
-24
lines changed

docs/models/supported_models.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,16 @@ If your model is not in the above list, we will try to automatically convert the
554554
For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
555555
e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
556556

557+
#### Token Classification
558+
These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode) API.
559+
560+
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
561+
|--------------|------------|-------------------|----------------------|---------------------------|---------------------|
562+
| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | |
563+
564+
!!! note
565+
For Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py> and <gh-file:examples/online_serving/pooling/ner.py>.
566+
557567
[](){ #supported-mm-models }
558568

559569
## List of Multimodal Language Models

examples/offline_inference/pooling/README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,11 @@ python examples/offline_inference/pooling/embed_matryoshka_fy.py
2929
## Qwen3 reranker usage
3030

3131
```bash
32-
python qwen3_reranker.py
32+
python examples/offline_inference/pooling/qwen3_reranker.py
3333
```
34+
35+
## Named Entity Recognition (NER) usage
36+
37+
```bash
38+
python examples/offline_inference/pooling/ner.py
39+
```

examples/offline_inference/ner.py renamed to examples/offline_inference/pooling/ner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER
34

45
from argparse import Namespace
56

@@ -25,13 +26,13 @@ def main(args: Namespace):
2526
prompts = ["Barack Obama visited Microsoft headquarters in Seattle on January 2025."]
2627

2728
# Create an LLM.
28-
# You should pass runner="pooling" for reward models
2929
llm = LLM(**vars(args))
30-
3130
tokenizer = llm.get_tokenizer()
3231
label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label
3332

34-
outputs = llm.reward(prompts)
33+
# Run inference
34+
outputs = llm.encode(prompts)
35+
3536
for prompt, output in zip(prompts, outputs):
3637
logits = output.outputs.data
3738
predictions = logits.argmax(dim=-1)

examples/online_serving/pooling/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,9 @@ python examples/online_serving/pooling/openai_embedding_matryoshka_fy.py
4141
```bash
4242
python examples/online_serving/pooling/openai_pooling_client.py
4343
```
44+
45+
## Named Entity Recognition (NER) usage
46+
47+
```bash
48+
python examples/online_serving/pooling/ner.py
49+
```

examples/online_serving/ner.py renamed to examples/online_serving/pooling/ner.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER
4+
35
"""
4-
Example online usage of Pooling API.
6+
Example online usage of Pooling API for Named Entity Recognition (NER).
57
68
Run `vllm serve <model> --runner pooling`
79
to start up the server in vLLM. e.g.
@@ -10,7 +12,6 @@
1012
"""
1113

1214
import argparse
13-
import pprint
1415

1516
import requests
1617
import torch
@@ -36,32 +37,32 @@ def main(args):
3637
api_url = f"http://{args.host}:{args.port}/pooling"
3738
model_name = args.model
3839

39-
40+
# Load tokenizer and config
4041
tokenizer = AutoTokenizer.from_pretrained(model_name)
4142
config = AutoConfig.from_pretrained(model_name)
4243
label_map = config.id2label
4344

45+
# Input text
4446
text = "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
45-
4647
prompt = {"model": model_name, "input": text}
47-
pooling_response = post_http_request(prompt=prompt, api_url=api_url)
4848

49-
outputs = pooling_response.json()["data"]
50-
51-
for output in outputs:
52-
logits = torch.tensor(output['data'])
53-
predictions = logits.argmax(dim=-1)
54-
55-
inputs = tokenizer(text, return_tensors="pt")
56-
57-
# Map predictions to labels
58-
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
59-
labels = [label_map[p.item()] for p in predictions]
49+
pooling_response = post_http_request(prompt=prompt, api_url=api_url)
6050

61-
# Print results
62-
for token, label in zip(tokens, labels):
63-
if token not in tokenizer.all_special_tokens:
64-
print(f"{token:15}{label}")
51+
# Extract the per-token logits from the response and pick the top label id
52+
output = pooling_response.json()["data"][0]
53+
logits = torch.tensor(output['data'])
54+
predictions = logits.argmax(dim=-1)
55+
inputs = tokenizer(text, return_tensors="pt")
56+
57+
# Map predictions to labels
58+
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
59+
labels = [label_map[p.item()] for p in predictions]
60+
assert len(tokens) == len(predictions)
61+
62+
# Print results
63+
for token, label in zip(tokens, labels):
64+
if token not in tokenizer.all_special_tokens:
65+
print(f"{token:15}{label}")
6566

6667

6768
if __name__ == "__main__":
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import pytest
4+
import torch
5+
from transformers import AutoModelForTokenClassification
6+
7+
from tests.models.utils import softmax
8+
9+
10+
@pytest.mark.parametrize(
11+
"model",
12+
[
13+
pytest.param("boltuix/NeuroBERT-NER", ),
14+
],
15+
)
16+
@pytest.mark.parametrize("dtype", ["float"])
17+
@torch.inference_mode
18+
def test_models(
19+
hf_runner,
20+
vllm_runner,
21+
example_prompts,
22+
model: str,
23+
dtype: str,
24+
) -> None:
25+
# float32 is required for this tiny model to pass the test.
26+
27+
with vllm_runner(model,
28+
max_model_len=None,
29+
dtype=dtype,
30+
enforce_eager=True) as vllm_model:
31+
vllm_outputs = vllm_model.encode(example_prompts)
32+
33+
with hf_runner(model,
34+
dtype=dtype,
35+
auto_cls=AutoModelForTokenClassification) as hf_model:
36+
tokenizer = hf_model.tokenizer
37+
hf_outputs = []
38+
for prompt in example_prompts:
39+
inputs = tokenizer([prompt], return_tensors="pt")
40+
inputs = hf_model.wrap_device(inputs)
41+
output = hf_model.model(**inputs)
42+
hf_outputs.append(softmax(output.logits[0]))
43+
44+
# Compare the softmax-normalized HF outputs against the vLLM outputs
45+
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
46+
hf_output = torch.tensor(hf_output).cpu().float()
47+
vllm_output = torch.tensor(vllm_output).cpu().float()
48+
assert torch.allclose(hf_output, vllm_output, 1e-2)

vllm/entrypoints/llm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,10 @@ def encode(
932932
considered legacy and may be deprecated in the future. You should
933933
instead pass them via the `inputs` parameter.
934934
"""
935+
936+
if self.supported_tasks == ["encode"]:
937+
pooling_task = "encode"
938+
935939
if pooling_task is None:
936940
if "embed" in self.supported_tasks:
937941
pooling_task = "embed"

0 commit comments

Comments
 (0)