Skip to content

Commit 11a7faf

Browse files
authored
[New Model]: Support GteNewModelForSequenceClassification (#23524)
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent 186aced commit 11a7faf

File tree

13 files changed

+157
-76
lines changed

13 files changed

+157
-76
lines changed

docs/models/supported_models.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
497497
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
498498
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ |
499499
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
500+
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ |
500501
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
501502
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
502503
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ |
@@ -513,6 +514,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
513514
vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'
514515
```
515516

517+
!!! note
518+
The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. Because the name `NewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
519+
516520
!!! note
517521
Load the official original `mxbai-rerank-v2` by using the following command.
518522

tests/conftest.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -456,11 +456,10 @@ def classify(self, prompts: list[str]) -> list[str]:
456456
# output is final logits
457457
all_inputs = self.get_inputs(prompts)
458458
outputs = []
459+
problem_type = getattr(self.config, "problem_type", "")
460+
459461
for inputs in all_inputs:
460462
output = self.model(**self.wrap_device(inputs))
461-
462-
problem_type = getattr(self.config, "problem_type", "")
463-
464463
if problem_type == "regression":
465464
logits = output.logits[0].tolist()
466465
elif problem_type == "multi_label_classification":

tests/models/language/pooling/embed_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ def correctness_test_embed_models(hf_runner,
5151
vllm_extra_kwargs = vllm_extra_kwargs or {}
5252
vllm_extra_kwargs["dtype"] = model_info.dtype
5353

54+
if model_info.hf_overrides is not None:
55+
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
56+
5457
with vllm_runner(model_info.name,
5558
runner="pooling",
5659
max_model_len=None,

tests/models/language/pooling/mteb_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ def mteb_test_embed_models(hf_runner,
172172
vllm_extra_kwargs = vllm_extra_kwargs or {}
173173
vllm_extra_kwargs["dtype"] = model_info.dtype
174174

175+
if model_info.hf_overrides is not None:
176+
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
177+
175178
with vllm_runner(model_info.name,
176179
runner="pooling",
177180
max_model_len=None,
@@ -284,6 +287,9 @@ def mteb_test_rerank_models(hf_runner,
284287
vllm_extra_kwargs = vllm_extra_kwargs or {}
285288
vllm_extra_kwargs["dtype"] = model_info.dtype
286289

290+
if model_info.hf_overrides is not None:
291+
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
292+
287293
with vllm_runner(model_info.name,
288294
runner="pooling",
289295
max_model_len=None,

tests/models/language/pooling/test_bge_reranker_v2_gemma.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@
1313

1414
RERANK_MODELS = [
1515
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
16-
architecture="GemmaForSequenceClassification"),
16+
architecture="GemmaForSequenceClassification",
17+
hf_overrides={
18+
"architectures":
19+
["GemmaForSequenceClassification"],
20+
"classifier_from_token": ["Yes"],
21+
"method":
22+
"no_post_processing",
23+
}),
1724
]
1825

1926
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
@@ -119,22 +126,9 @@ def predict(
119126

120127

121128
@pytest.mark.parametrize("model_info", RERANK_MODELS)
122-
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo,
123-
monkeypatch) -> None:
124-
monkeypatch.setenv("VLLM_USE_V1", "0")
125-
126-
assert model_info.architecture == "GemmaForSequenceClassification"
127-
128-
vllm_extra_kwargs: dict[str, Any] = {
129-
"hf_overrides": {
130-
"architectures": ["GemmaForSequenceClassification"],
131-
"classifier_from_token": ["Yes"],
132-
"method": "no_post_processing",
133-
}
134-
}
129+
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
135130

136131
mteb_test_rerank_models(GemmaRerankerHfRunner,
137132
vllm_runner,
138133
model_info,
139-
vllm_extra_kwargs,
140134
vllm_mteb_encoder=GemmaMtebEncoder)

tests/models/language/pooling/test_gte.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
from typing import Any
43

54
import pytest
65

@@ -33,12 +32,15 @@
3332
########### NewModel
3433
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
3534
architecture="GteNewModel",
35+
hf_overrides={"architectures": ["GteNewModel"]},
3636
enable_test=True),
3737
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
3838
architecture="GteNewModel",
39+
hf_overrides={"architectures": ["GteNewModel"]},
3940
enable_test=True),
4041
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
4142
architecture="GteNewModel",
43+
hf_overrides={"architectures": ["GteNewModel"]},
4244
enable_test=True),
4345
########### Qwen2ForCausalLM
4446
LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
@@ -60,11 +62,16 @@
6062
]
6163

6264
RERANK_MODELS = [
63-
# classifier_pooling: mean
6465
CLSPoolingRerankModelInfo(
66+
# classifier_pooling: mean
6567
"Alibaba-NLP/gte-reranker-modernbert-base",
6668
architecture="ModernBertForSequenceClassification",
6769
enable_test=True),
70+
CLSPoolingRerankModelInfo(
71+
"Alibaba-NLP/gte-multilingual-reranker-base",
72+
architecture="GteNewForSequenceClassification",
73+
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
74+
enable_test=True),
6875
]
6976

7077

@@ -75,12 +82,7 @@ def test_embed_models_mteb(hf_runner, vllm_runner,
7582
check_transformers_version(model_info.name,
7683
max_transformers_version="4.53.2")
7784

78-
vllm_extra_kwargs: dict[str, Any] = {}
79-
if model_info.architecture == "GteNewModel":
80-
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
81-
82-
mteb_test_embed_models(hf_runner, vllm_runner, model_info,
83-
vllm_extra_kwargs)
85+
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
8486

8587

8688
@pytest.mark.parametrize("model_info", MODELS)
@@ -91,12 +93,8 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
9193
check_transformers_version(model_info.name,
9294
max_transformers_version="4.53.2")
9395

94-
vllm_extra_kwargs: dict[str, Any] = {}
95-
if model_info.architecture == "GteNewModel":
96-
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
97-
9896
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
99-
example_prompts, vllm_extra_kwargs)
97+
example_prompts)
10098

10199

102100
@pytest.mark.parametrize("model_info", RERANK_MODELS)

tests/models/language/pooling/test_mxbai_rerank.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,20 @@
1010
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
1111
from .mteb_utils import mteb_test_rerank_models
1212

13+
mxbai_rerank_hf_overrides = {
14+
"architectures": ["Qwen2ForSequenceClassification"],
15+
"classifier_from_token": ["0", "1"],
16+
"method": "from_2_way_softmax",
17+
}
18+
1319
RERANK_MODELS = [
1420
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
1521
architecture="Qwen2ForSequenceClassification",
22+
hf_overrides=mxbai_rerank_hf_overrides,
1623
enable_test=True),
1724
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
1825
architecture="Qwen2ForSequenceClassification",
26+
hf_overrides=mxbai_rerank_hf_overrides,
1927
enable_test=False)
2028
]
2129

@@ -71,13 +79,4 @@ def compute_logits(inputs):
7179

7280
@pytest.mark.parametrize("model_info", RERANK_MODELS)
7381
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
74-
vllm_extra_kwargs: dict[str, Any] = {}
75-
if model_info.architecture == "Qwen2ForSequenceClassification":
76-
vllm_extra_kwargs["hf_overrides"] = {
77-
"architectures": ["Qwen2ForSequenceClassification"],
78-
"classifier_from_token": ["0", "1"],
79-
"method": "from_2_way_softmax",
80-
}
81-
82-
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info,
83-
vllm_extra_kwargs)
82+
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info)

tests/models/language/pooling/test_qwen3_reranker.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,20 @@
1111
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
1212
from .mteb_utils import mteb_test_rerank_models
1313

14+
qwen3_reranker_hf_overrides = {
15+
"architectures": ["Qwen3ForSequenceClassification"],
16+
"classifier_from_token": ["no", "yes"],
17+
"is_original_qwen3_reranker": True,
18+
}
19+
1420
RERANK_MODELS = [
1521
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
1622
architecture="Qwen3ForSequenceClassification",
23+
hf_overrides=qwen3_reranker_hf_overrides,
1724
enable_test=True),
1825
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B",
1926
architecture="Qwen3ForSequenceClassification",
27+
hf_overrides=qwen3_reranker_hf_overrides,
2028
enable_test=False)
2129
]
2230

@@ -74,18 +82,7 @@ def compute_logits(inputs):
7482
@pytest.mark.parametrize("model_info", RERANK_MODELS)
7583
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
7684

77-
assert model_info.architecture == "Qwen3ForSequenceClassification"
78-
79-
vllm_extra_kwargs: dict[str, Any] = {
80-
"hf_overrides": {
81-
"architectures": ["Qwen3ForSequenceClassification"],
82-
"classifier_from_token": ["no", "yes"],
83-
"is_original_qwen3_reranker": True,
84-
}
85-
}
86-
87-
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info,
88-
vllm_extra_kwargs)
85+
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info)
8986

9087

9188
@pytest.mark.parametrize("model_info", RERANK_MODELS)
@@ -96,11 +93,6 @@ def test_rerank_models_mteb_tp(vllm_runner,
9693
assert model_info.architecture == "Qwen3ForSequenceClassification"
9794

9895
vllm_extra_kwargs: dict[str, Any] = {
99-
"hf_overrides": {
100-
"architectures": ["Qwen3ForSequenceClassification"],
101-
"classifier_from_token": ["no", "yes"],
102-
"is_original_qwen3_reranker": True,
103-
},
10496
"tensor_parallel_size": 2,
10597
}
10698

tests/models/registry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,10 @@ def check_available_online(
365365

366366
# [Cross-encoder]
367367
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
368+
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
369+
trust_remote_code=True,
370+
hf_overrides={
371+
"architectures": ["GteNewForSequenceClassification"]}),# noqa: E501
368372
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
369373
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
370374
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501

tests/models/utils.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
import warnings
55
from collections.abc import Sequence
6-
from typing import Any, NamedTuple, Optional, Union
6+
from dataclasses import dataclass
7+
from typing import Any, Optional, Union
78

89
import torch
910
import torch.nn.functional as F
@@ -339,36 +340,43 @@ def softmax(data):
339340
return F.softmax(data, dim=-1)
340341

341342

342-
class EmbedModelInfo(NamedTuple):
343+
@dataclass
344+
class ModelInfo:
343345
name: str
344-
is_matryoshka: bool = False
345-
matryoshka_dimensions: Optional[list[int]] = None
346346
architecture: str = ""
347347
dtype: str = "auto"
348+
hf_overrides: Optional[dict[str, Any]] = None
348349
default_pooling_type: str = ""
349350
enable_test: bool = True
350351

351352

353+
@dataclass
354+
class EmbedModelInfo(ModelInfo):
355+
is_matryoshka: bool = False
356+
matryoshka_dimensions: Optional[list[int]] = None
357+
358+
359+
@dataclass
352360
class CLSPoolingEmbedModelInfo(EmbedModelInfo):
353361
default_pooling_type: str = "CLS"
354362

355363

364+
@dataclass
356365
class LASTPoolingEmbedModelInfo(EmbedModelInfo):
357366
default_pooling_type: str = "LAST"
358367

359368

360-
class RerankModelInfo(NamedTuple):
361-
name: str
362-
architecture: str = ""
363-
dtype: str = "auto"
364-
default_pooling_type: str = ""
365-
enable_test: bool = True
369+
@dataclass
370+
class RerankModelInfo(ModelInfo):
371+
pass
366372

367373

374+
@dataclass
368375
class CLSPoolingRerankModelInfo(RerankModelInfo):
369376
default_pooling_type: str = "CLS"
370377

371378

379+
@dataclass
372380
class LASTPoolingRerankModelInfo(RerankModelInfo):
373381
default_pooling_type: str = "LAST"
374382

0 commit comments

Comments
 (0)