Commit 4d2572e

+ GteNewForSequenceClassification
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent 42fce02 · commit 4d2572e

6 files changed (+114, -22 lines)


docs/models/supported_models.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -497,6 +497,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
 |--------------|--------|-------------------|----------------------|---------------------------|---------------------|
 | `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ |
 | `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
+| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ |
 | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
 | `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ |
@@ -513,6 +514,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
     vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'
     ```
 
+!!! note
+    The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. The name `NewForSequenceClassification` is too generic; you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
+
 !!! note
     Load the official original `mxbai-rerank-v2` by using the following command.
 
```
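The override in the note above can also be passed through the offline API. A minimal sketch (not part of this commit, and not verified here; it assumes a vLLM build where cross-encoder models expose `LLM.score` and scoring outputs carry an `.outputs.score` field, as the surrounding docs describe):

```python
# Sketch: pass the same override as --hf-overrides via hf_overrides, then
# score query/document pairs with the cross-encoder.
from vllm import LLM

llm = LLM(
    model="Alibaba-NLP/gte-multilingual-reranker-base",
    trust_remote_code=True,
    hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
)

outputs = llm.score(
    "What is the capital of France?",
    ["Paris is the capital of France.", "The Nile flows through Egypt."],
)
for out in outputs:
    print(out.outputs.score)  # one relevance score per document
```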
tests/models/language/pooling/test_gte.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -60,11 +60,14 @@
 ]
 
 RERANK_MODELS = [
-    # classifier_pooling: mean
     CLSPoolingRerankModelInfo(
+        # classifier_pooling: mean
         "Alibaba-NLP/gte-reranker-modernbert-base",
        architecture="ModernBertForSequenceClassification",
        enable_test=True),
+    CLSPoolingRerankModelInfo("Alibaba-NLP/gte-multilingual-reranker-base",
+                              architecture="GteNewForSequenceClassification",
+                              enable_test=True),
 ]
 
 
@@ -102,4 +105,11 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
 @pytest.mark.parametrize("model_info", RERANK_MODELS)
 def test_rerank_models_mteb(hf_runner, vllm_runner,
                             model_info: RerankModelInfo) -> None:
-    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
+    vllm_extra_kwargs: dict[str, Any] = {}
+    if model_info.architecture == "GteNewForSequenceClassification":
+        vllm_extra_kwargs["hf_overrides"] = {
+            "architectures": ["GteNewForSequenceClassification"]
+        }
+
+    mteb_test_rerank_models(hf_runner, vllm_runner, model_info,
+                            vllm_extra_kwargs)
```
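The extra kwargs are needed because the checkpoint's `config.json` declares the generic architecture name rather than a vLLM-registered one (see the docs note above). A quick way to confirm what the checkpoint declares, sketched with `transformers` and not run as part of this commit:

```python
# Inspect the architecture name the HF checkpoint actually declares; per the
# docs note it should be the generic 'NewForSequenceClassification'.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained(
    "Alibaba-NLP/gte-multilingual-reranker-base", trust_remote_code=True)
print(cfg.architectures)  # expected: ['NewForSequenceClassification']
```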

tests/models/registry.py

Lines changed: 14 additions & 12 deletions
```diff
@@ -323,8 +323,8 @@ def check_available_online(
 
 _EMBEDDING_EXAMPLE_MODELS = {
     # [Text-only]
-    "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
-    "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True),  # noqa: E501
+    "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
+    "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),  # noqa: E501
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                 trust_remote_code=True),
@@ -337,19 +337,19 @@ def check_available_online(
     "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
     "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
     "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
-                                       trust_remote_code=True, v0_only=True),
+                                       trust_remote_code=True),
     "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
-                                      trust_remote_code=True, v0_only=True),  # noqa: E501
+                                      trust_remote_code=True),  # noqa: E501
     "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
     "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B",
                                            max_transformers_version="4.53",
                                            transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"),  # noqa: E501
     "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B",
                                                   max_transformers_version="4.53",
                                                   transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"),  # noqa: E501
-    "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True),  # noqa: E501
-    "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True),  # noqa: E501
-    "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True),  # noqa: E501
+    "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),  # noqa: E501
+    "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),  # noqa: E501
+    "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),  # noqa: E501
     # [Multimodal]
     "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
     "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
@@ -364,16 +364,18 @@ def check_available_online(
     "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"),  # noqa: E501
 
     # [Cross-encoder]
-    "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True),  # noqa: E501
-    "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True),  # noqa: E501
-    "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True),  # noqa: E501
-    "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True),  # noqa: E501
+    "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"),  # noqa: E501
+    "GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base",  # noqa: E501
+                                                       trust_remote_code=True,
+                                                       hf_overrides={"architectures": ["GteNewForSequenceClassification"]}),  # noqa: E501
+    "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"),  # noqa: E501
+    "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"),  # noqa: E501
+    "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"),  # noqa: E501
 }
 
 _AUTOMATIC_CONVERTED_MODELS = {
     # Use as_seq_cls_model for automatic conversion
     "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma",  # noqa: E501
-                                                      v0_only=True,
                                                       hf_overrides={"architectures": ["GemmaForSequenceClassification"],  # noqa: E501
                                                                     "classifier_from_token": ["Yes"],  # noqa: E501
                                                                     "method": "no_post_processing"}),  # noqa: E501
```

vllm/model_executor/models/bert_with_rope.py

Lines changed: 79 additions & 6 deletions
```diff
@@ -22,14 +22,19 @@
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.pooler import (ClassifierPooler,
+                                               DispatchPooler, Pooler)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import (SupportsQuant,
+from vllm.model_executor.models.bert import BertPooler
+from vllm.model_executor.models.interfaces import (SupportsCrossEncoding,
+                                                   SupportsQuant,
                                                    default_pooling_type)
-from vllm.model_executor.models.utils import WeightsMapper
+from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
+                                              maybe_prefix)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -405,16 +410,22 @@ def forward(
 class BertWithRope(nn.Module, SupportsQuant):
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 add_pooling_layer: bool = False):
         super().__init__()
         self.vllm_config = vllm_config
+        self.add_pooling_layer = add_pooling_layer
         self.config = vllm_config.model_config.hf_config
         self.embeddings = BertWithRopeEmbedding(self.config)
         self.encoder = BertWithRopeEncoder(
             vllm_config=vllm_config,
             bias=getattr(self.config, "bias", True),
             rotary_kwargs=self.config.rotary_kwargs,
             prefix=f"{prefix}.encoder")
+        self.pooler = BertPooler(self.config) if add_pooling_layer else None
 
     def forward(
         self,
@@ -447,7 +458,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            if "pooler" in name:
+            if not self.add_pooling_layer and "pooler" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
@@ -507,8 +518,8 @@ class GteNewModel(BertWithRope):
         "attention.o_proj": "attn.out_proj",
     })
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
+        super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
         # GteNewModel only gate_up_proj does not have bias.
         # Hack method learned from vllm/model_executor/models/glm.py
@@ -613,3 +624,65 @@ def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         weights = self.jina_merge_lora_weights(weights)
         return super().load_weights(weights)
+
+
+@default_pooling_type("CLS")
+class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
+    is_pooling_model = True
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
+        self.new = GteNewModel(vllm_config=vllm_config,
+                               prefix=prefix,
+                               add_pooling_layer=True)
+        self.classifier = RowParallelLinear(config.hidden_size,
+                                            config.num_labels,
+                                            input_is_parallel=False,
+                                            bias=True,
+                                            quant_config=quant_config,
+                                            prefix=maybe_prefix(
+                                                prefix, "classifier"),
+                                            return_bias=False)
+
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler({
+            "encode":
+            Pooler.for_encode(pooler_config),
+            "classify":
+            ClassifierPooler(
+                pooling=self.new.pooler,
+                classifier=self.classifier,
+                act_fn=ClassifierPooler.act_fn_for_seq_cls(
+                    vllm_config.model_config),
+            ),
+            "score":
+            ClassifierPooler(
+                pooling=self.new.pooler,
+                classifier=self.classifier,
+                act_fn=ClassifierPooler.act_fn_for_cross_encoder(
+                    vllm_config.model_config),
+            ),
+        })
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self)
+        loaded_params = loader.load_weights(weights)
+        return loaded_params
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        return self.new(input_ids=input_ids,
+                        positions=positions,
+                        inputs_embeds=inputs_embeds,
+                        intermediate_tensors=intermediate_tensors)
```
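To make the pooler wiring above concrete: on the `score` path the model CLS-pools the final hidden states through the `BertPooler` head, applies the linear classifier, and squashes the logit with the cross-encoder activation. Below is a plain-PyTorch sketch of that computation, illustrative only and not vLLM's actual code path; the sigmoid is an assumption for the common `num_labels == 1` reranker case:

```python
import torch
import torch.nn as nn

hidden_size, num_labels = 768, 1
pooler_dense = nn.Linear(hidden_size, hidden_size)  # BertPooler's dense layer
classifier = nn.Linear(hidden_size, num_labels)     # plays the role of self.classifier

def score(last_hidden_states: torch.Tensor) -> torch.Tensor:
    cls = last_hidden_states[:, 0]           # "CLS" pooling: first token
    pooled = torch.tanh(pooler_dense(cls))   # BertPooler: dense + tanh
    logits = classifier(pooled)              # linear classification head
    return torch.sigmoid(logits)             # assumed cross-encoder act_fn

print(score(torch.randn(2, 16, hidden_size)).shape)  # torch.Size([2, 1])
```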

vllm/model_executor/models/config.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -365,6 +365,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
 MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "GteModel": SnowflakeGteNewModelConfig,
     "GteNewModel": GteNewModelConfig,
+    "GteNewForSequenceClassification": GteNewModelConfig,
     "NomicBertModel": NomicBertModelConfig,
     "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
     "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
```

vllm/model_executor/models/registry.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -188,12 +188,14 @@
 
 _CROSS_ENCODER_MODELS = {
     "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
+    "GteNewForSequenceClassification": ("bert_with_rope",
+                                        "GteNewForSequenceClassification"),
     "RobertaForSequenceClassification": ("roberta",
                                          "RobertaForSequenceClassification"),
-    "XLMRobertaForSequenceClassification": ("roberta",
-                                            "RobertaForSequenceClassification"),
     "ModernBertForSequenceClassification": ("modernbert",
                                             "ModernBertForSequenceClassification"),
+    "XLMRobertaForSequenceClassification": ("roberta",
+                                            "RobertaForSequenceClassification"),
     # [Auto-converted (see adapters.py)]
     "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),  # noqa: E501
 }
```
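Each registry value is a `(module, class)` pair rooted at `vllm/model_executor/models/`; the registry imports the module lazily and pulls the class out by name. An illustrative sketch of that resolution, not vLLM's actual `ModelRegistry` implementation:

```python
import importlib

def resolve_architecture(module_name: str, class_name: str) -> type:
    # e.g. ("bert_with_rope", "GteNewForSequenceClassification")
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)
```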
