Skip to content

Commit aebda1b

Browse files
noooopIsotr0py
authored and committed
[New Model] Support BertForTokenClassification / Named Entity Recognition (NER) task (vllm-project#24872)
Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: charlifu <charlifu@amd.com>
1 parent ad12e8b commit aebda1b

File tree

11 files changed

+257
-2
lines changed

11 files changed

+257
-2
lines changed

docs/models/supported_models.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,17 @@ If your model is not in the above list, we will try to automatically convert the
554554
For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
555555
e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
556556

557+
#### Token Classification
558+
559+
These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode) API.
560+
561+
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
562+
|--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------|
563+
| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ |
564+
565+
!!! note
566+
For Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py> and <gh-file:examples/online_serving/pooling/ner.py>.
567+
557568
[](){ #supported-mm-models }
558569

559570
## List of Multimodal Language Models

examples/offline_inference/pooling/README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,14 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
2626
python examples/offline_inference/pooling/embed_matryoshka_fy.py
2727
```
2828

29+
## Named Entity Recognition (NER) usage
30+
31+
```bash
32+
python examples/offline_inference/pooling/ner.py
33+
```
34+
2935
## Qwen3 reranker usage
3036

3137
```bash
32-
python qwen3_reranker.py
38+
python examples/offline_inference/pooling/qwen3_reranker.py
3339
```
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER
4+
5+
from argparse import Namespace
6+
7+
from vllm import LLM, EngineArgs
8+
from vllm.utils import FlexibleArgumentParser
9+
10+
11+
def parse_args():
    """Build the engine CLI parser, apply example defaults, and parse argv."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    # Defaults tuned for this NER example; any of them can still be
    # overridden from the command line.
    parser.set_defaults(
        model="boltuix/NeuroBERT-NER",
        runner="pooling",
        enforce_eager=True,
        trust_remote_code=True,
    )
    return parser.parse_args()
22+
23+
24+
def main(args: Namespace):
    """Run offline token classification (NER) and print token/label pairs.

    Args:
        args: Engine arguments parsed by ``parse_args()``; forwarded
            verbatim to the ``LLM`` constructor.
    """
    # Sample prompts.
    prompts = [
        "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
    ]

    # Create an LLM (pooling runner, see parse_args defaults).
    llm = LLM(**vars(args))
    tokenizer = llm.get_tokenizer()
    # id -> label-name mapping comes straight from the HF model config.
    label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label

    # Run inference: `encode` returns per-token outputs for each prompt.
    outputs = llm.encode(prompts)

    # NOTE: iterate outputs directly; each output carries its own
    # prompt_token_ids, so zipping with `prompts` was redundant.
    for output in outputs:
        logits = output.outputs.data
        predictions = logits.argmax(dim=-1)

        # Map predictions to labels.
        tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids)
        labels = [label_map[p.item()] for p in predictions]
        # Sanity check (mirrors the online example): one prediction per token.
        if len(tokens) != len(labels):
            raise RuntimeError(
                f"Token/label count mismatch: {len(tokens)} tokens vs "
                f"{len(labels)} predictions.")

        # Print results, skipping special tokens such as [CLS]/[SEP].
        for token, label in zip(tokens, labels):
            if token not in tokenizer.all_special_tokens:
                print(f"{token:15}{label}")


if __name__ == "__main__":
    args = parse_args()
    main(args)

examples/online_serving/pooling/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ python examples/online_serving/pooling/cohere_rerank_client.py
1212
python examples/online_serving/pooling/jinaai_rerank_client.py
1313
```
1414

15+
## Named Entity Recognition (NER) usage
16+
17+
```bash
18+
python examples/online_serving/pooling/ner.py
19+
```
20+
1521
## Openai chat embedding for multimodal usage
1622

1723
```bash
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER
4+
5+
"""
6+
Example online usage of Pooling API for Named Entity Recognition (NER).
7+
8+
Run `vllm serve <model> --runner pooling`
9+
to start up the server in vLLM. e.g.
10+
11+
vllm serve boltuix/NeuroBERT-NER
12+
"""
13+
14+
import argparse
15+
16+
import requests
17+
import torch
18+
19+
20+
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST *prompt* as a JSON body to *api_url* and return the raw response."""
    return requests.post(
        api_url,
        headers={"User-Agent": "Test Client"},
        json=prompt,
    )
24+
25+
26+
def parse_args():
    """Parse command-line options for the NER pooling client."""
    parser = argparse.ArgumentParser()
    # (flag, type, default) triples for the supported options.
    for flag, kind, default in (
        ("--host", str, "localhost"),
        ("--port", int, 8000),
        ("--model", str, "boltuix/NeuroBERT-NER"),
    ):
        parser.add_argument(flag, type=kind, default=default)
    return parser.parse_args()
34+
35+
def main(args):
    """Query the vLLM pooling endpoint and print NER token/label pairs.

    Args:
        args: Namespace with ``host``, ``port`` and ``model`` attributes
            (see ``parse_args``).
    """
    from transformers import AutoConfig, AutoTokenizer

    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    # Load tokenizer and config (config supplies the id -> label mapping).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    config = AutoConfig.from_pretrained(model_name)
    label_map = config.id2label

    # Input text
    text = "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
    prompt = {"model": model_name, "input": text}

    pooling_response = post_http_request(prompt=prompt, api_url=api_url)
    # Fail fast with a clear HTTP error instead of a cryptic KeyError
    # when the server rejects the request.
    pooling_response.raise_for_status()

    # Run inference: the server returns one logit vector per input token.
    output = pooling_response.json()["data"][0]
    logits = torch.tensor(output["data"])
    predictions = logits.argmax(dim=-1)
    inputs = tokenizer(text, return_tensors="pt")

    # Map predictions to labels.
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [label_map[p.item()] for p in predictions]
    # Use a real exception, not `assert`: asserts are stripped under -O.
    if len(tokens) != len(predictions):
        raise RuntimeError(
            f"Tokenized locally into {len(tokens)} tokens but the server "
            f"returned {len(predictions)} predictions.")

    # Print results, skipping special tokens such as [CLS]/[SEP].
    for token, label in zip(tokens, labels):
        if token not in tokenizer.all_special_tokens:
            print(f"{token:15}{label}")


if __name__ == "__main__":
    args = parse_args()
    main(args)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import pytest
4+
import torch
5+
from transformers import AutoModelForTokenClassification
6+
7+
from tests.models.utils import softmax
8+
9+
10+
@pytest.mark.parametrize("model", ["boltuix/NeuroBERT-NER"])
# The float32 is required for this tiny model to pass the test.
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Compare vLLM token-classification outputs against the HF reference.

    Runs the same prompts through vLLM's pooling runner and through the
    HuggingFace ``AutoModelForTokenClassification`` model, then checks
    that the per-token outputs agree within tolerance.
    """
    # vLLM side: `encode` returns one pooled tensor per prompt; with this
    # model that is a per-token output (one row per input token).
    with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.encode(example_prompts)

    # HF reference: run each prompt individually and softmax the logits of
    # the single sequence in the batch (index 0).
    with hf_runner(model,
                   dtype=dtype,
                   auto_cls=AutoModelForTokenClassification) as hf_model:
        tokenizer = hf_model.tokenizer
        hf_outputs = []
        for prompt in example_prompts:
            inputs = tokenizer([prompt], return_tensors="pt")
            inputs = hf_model.wrap_device(inputs)
            output = hf_model.model(**inputs)
            hf_outputs.append(softmax(output.logits[0]))

    # check logits difference
    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        hf_output = torch.tensor(hf_output).cpu().float()
        vllm_output = torch.tensor(vllm_output).cpu().float()
        # NOTE(review): the third positional argument of `allclose` is
        # rtol (=1e-2); atol keeps its default — confirm this is intended.
        assert torch.allclose(hf_output, vllm_output, 1e-2)

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ def check_available_online(
414414

415415
# [Cross-encoder]
416416
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
417+
"BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"),
417418
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
418419
trust_remote_code=True,
419420
hf_overrides={

vllm/entrypoints/llm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,10 @@ def encode(
943943
considered legacy and may be deprecated in the future. You should
944944
instead pass them via the `inputs` parameter.
945945
"""
946+
947+
if self.supported_tasks == ["encode"] and pooling_task is None:
948+
pooling_task = "encode"
949+
946950
if pooling_task is None:
947951
if "embed" in self.supported_tasks:
948952
pooling_task = "embed"

vllm/model_executor/models/bert.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,3 +611,55 @@ def forward(
611611
positions=positions,
612612
inputs_embeds=inputs_embeds,
613613
intermediate_tensors=intermediate_tensors)
614+
615+
616+
@default_pooling_type("ALL")
class BertForTokenClassification(nn.Module):
    """BERT with a per-token classification head (e.g. for NER).

    Wraps the vLLM ``BertModel`` backbone and applies a linear classifier
    to every hidden state; the default "ALL" pooling type exposes one
    output row per input token.
    """

    # Marks this architecture as a pooling model for the vLLM runner.
    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        # Dtype used for the classifier head; may differ from the
        # backbone's compute dtype.
        self.head_dtype = vllm_config.model_config.head_dtype
        self.num_labels = config.num_labels
        self.bert = BertModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "bert"),
                              embedding_class=BertEmbedding)
        self.classifier = nn.Linear(config.hidden_size,
                                    config.num_labels,
                                    dtype=self.head_dtype)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        # Only the "encode" task is supported: raw per-token outputs.
        self.pooler = DispatchPooler({
            "encode":
            Pooler.for_encode(pooler_config),
        })

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this module via ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        loaded_params = loader.load_weights(weights)
        return loaded_params

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return per-token classification logits for the input sequence.

        When ``token_type_ids`` is provided, it is folded into
        ``input_ids`` in place via ``_encode_token_type_ids`` —
        presumably packed into the bits above ``TOKEN_TYPE_SHIFT``
        (hence the vocab-size assertion); the backbone is expected to
        unpack it. TODO(review): confirm against ``_encode_token_type_ids``.
        """

        if token_type_ids is not None:
            # Packing is only safe if all token ids fit below the shift.
            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)

        hidden_states = self.bert(input_ids=input_ids,
                                  positions=positions,
                                  inputs_embeds=inputs_embeds,
                                  intermediate_tensors=intermediate_tensors)

        # Cast to the head dtype before the final projection.
        hidden_states = hidden_states.to(self.head_dtype)
        return self.classifier(hidden_states)

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@
193193

194194
_CROSS_ENCODER_MODELS = {
195195
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
196+
"BertForTokenClassification": ("bert", "BertForTokenClassification"),
196197
"GteNewForSequenceClassification": ("bert_with_rope",
197198
"GteNewForSequenceClassification"),
198199
"ModernBertForSequenceClassification": ("modernbert",

0 commit comments

Comments
 (0)