Skip to content

Commit d0078ec

Browse files
committed
+ ner
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent 78818dd commit d0078ec

File tree

4 files changed

+106
-0
lines changed

4 files changed

+106
-0
lines changed

examples/offline_inference/ner.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from argparse import Namespace
5+
6+
from vllm import LLM, EngineArgs
7+
from vllm.utils import FlexibleArgumentParser
8+
9+
10+
def parse_args():
11+
parser = FlexibleArgumentParser()
12+
parser = EngineArgs.add_cli_args(parser)
13+
# Set example specific arguments
14+
parser.set_defaults(
15+
model="boltuix/NeuroBERT-NER",
16+
runner="pooling",
17+
enforce_eager=True,
18+
trust_remote_code=True,
19+
)
20+
return parser.parse_args()
21+
22+
23+
def main(args: Namespace):
24+
# Sample prompts.
25+
prompts = ["Barack Obama visited Microsoft headquarters in Seattle on January 2025."]
26+
27+
# Create an LLM.
28+
# You should pass runner="pooling" for reward models
29+
llm = LLM(**vars(args))
30+
31+
tokenizer = llm.get_tokenizer()
32+
label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label
33+
34+
outputs = llm.reward(prompts)
35+
for prompt, output in zip(prompts, outputs):
36+
logits = output.outputs.data
37+
predictions = logits.argmax(dim=-1)
38+
39+
# Map predictions to labels
40+
tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids)
41+
labels = [label_map[p.item()] for p in predictions]
42+
43+
# Print results
44+
for token, label in zip(tokens, labels):
45+
if token not in tokenizer.all_special_tokens:
46+
print(f"{token:15}{label}")
47+
48+
49+
if __name__ == "__main__":
50+
args = parse_args()
51+
main(args)

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,7 @@ def check_available_online(
417417

418418
# [Cross-encoder]
419419
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
420+
"BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"),
420421
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
421422
trust_remote_code=True,
422423
hf_overrides={

vllm/model_executor/models/bert.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,3 +611,56 @@ def forward(
611611
positions=positions,
612612
inputs_embeds=inputs_embeds,
613613
intermediate_tensors=intermediate_tensors)
614+
615+
616+
@default_pooling_type("ALL")
617+
class BertForTokenClassification(nn.Module, SupportsCrossEncoding,
618+
SupportsQuant):
619+
is_pooling_model = True
620+
621+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
622+
super().__init__()
623+
config = vllm_config.model_config.hf_config
624+
self.head_dtype = vllm_config.model_config.head_dtype
625+
self.num_labels = config.num_labels
626+
self.bert = BertModel(vllm_config=vllm_config,
627+
prefix=maybe_prefix(prefix, "bert"),
628+
embedding_class=BertEmbedding)
629+
self.classifier = nn.Linear(config.hidden_size,
630+
config.num_labels,
631+
dtype=self.head_dtype)
632+
633+
pooler_config = vllm_config.model_config.pooler_config
634+
assert pooler_config is not None
635+
636+
self.pooler = DispatchPooler({
637+
"encode":
638+
Pooler.for_encode(pooler_config),
639+
})
640+
641+
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
642+
loader = AutoWeightsLoader(self)
643+
loaded_params = loader.load_weights(weights)
644+
return loaded_params
645+
646+
def forward(
647+
self,
648+
input_ids: Optional[torch.Tensor],
649+
positions: torch.Tensor,
650+
intermediate_tensors: Optional[IntermediateTensors] = None,
651+
inputs_embeds: Optional[torch.Tensor] = None,
652+
token_type_ids: Optional[torch.Tensor] = None,
653+
) -> torch.Tensor:
654+
655+
if token_type_ids is not None:
656+
assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
657+
assert input_ids is not None
658+
_encode_token_type_ids(input_ids, token_type_ids)
659+
660+
hidden_states = self.bert(input_ids=input_ids,
661+
positions=positions,
662+
inputs_embeds=inputs_embeds,
663+
intermediate_tensors=intermediate_tensors)
664+
665+
hidden_states = hidden_states.to(self.head_dtype)
666+
return self.classifier(hidden_states)

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@
196196

197197
_CROSS_ENCODER_MODELS = {
198198
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
199+
"BertForTokenClassification": ("bert", "BertForTokenClassification"),
199200
"GteNewForSequenceClassification": ("bert_with_rope",
200201
"GteNewForSequenceClassification"),
201202
"ModernBertForSequenceClassification": ("modernbert",

0 commit comments

Comments
 (0)