diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 20cf75873af7..34e465584888 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -497,6 +497,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A |--------------|--------|-------------------|----------------------|---------------------------|---------------------| | `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ | | `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | +| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ | | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ | @@ -513,6 +514,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}' ``` +!!! note + The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. The name `NewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture. + !!! note Load the official original `mxbai-rerank-v2` by using the following command. diff --git a/tests/conftest.py b/tests/conftest.py index f8bfdfc8e625..fe329f54cce8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -454,11 +454,10 @@ def classify(self, prompts: list[str]) -> list[str]: # output is final logits all_inputs = self.get_inputs(prompts) outputs = [] + problem_type = getattr(self.config, "problem_type", "") + for inputs in all_inputs: output = self.model(**self.wrap_device(inputs)) - - problem_type = getattr(self.config, "problem_type", "") - if problem_type == "regression": logits = output.logits[0].tolist() elif problem_type == "multi_label_classification": diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py index 61c5fcab4f8a..a74ad2aa2597 100644 --- a/tests/models/language/pooling/embed_utils.py +++ b/tests/models/language/pooling/embed_utils.py @@ -51,6 +51,9 @@ def correctness_test_embed_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index 4a1f8a53d024..640858125bfc 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -172,6 +172,9 @@ def mteb_test_embed_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, @@ -284,6 +287,9 @@ def mteb_test_rerank_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py index 206524d7caad..f473e0ba01ff 100644 --- a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py @@ -13,7 +13,14 @@ RERANK_MODELS = [ LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", - architecture="GemmaForSequenceClassification"), + architecture="GemmaForSequenceClassification", + hf_overrides={ + "architectures": + ["GemmaForSequenceClassification"], + "classifier_from_token": ["Yes"], + "method": + "no_post_processing", + }), ] PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501 @@ -119,22 +126,9 @@ def predict( @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo, - monkeypatch) -> None: - monkeypatch.setenv("VLLM_USE_V1", "0") - - assert model_info.architecture == "GemmaForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["GemmaForSequenceClassification"], - "classifier_from_token": ["Yes"], - "method": "no_post_processing", - } - } +def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: mteb_test_rerank_models(GemmaRerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs, vllm_mteb_encoder=GemmaMtebEncoder) diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index f805a64103c0..9911620c018e 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any import pytest @@ -33,12 +32,15 @@ ########### NewModel CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base", architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, enable_test=True), CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, enable_test=True), CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, enable_test=True), ########### Qwen2ForCausalLM LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", @@ -60,11 +62,16 @@ ] RERANK_MODELS = [ - # classifier_pooling: mean CLSPoolingRerankModelInfo( + # classifier_pooling: mean "Alibaba-NLP/gte-reranker-modernbert-base", architecture="ModernBertForSequenceClassification", enable_test=True), + CLSPoolingRerankModelInfo( + "Alibaba-NLP/gte-multilingual-reranker-base", + architecture="GteNewForSequenceClassification", + hf_overrides={"architectures": ["GteNewForSequenceClassification"]}, + enable_test=True), ] @@ -75,12 +82,7 @@ def test_embed_models_mteb(hf_runner, vllm_runner, check_transformers_version(model_info.name, max_transformers_version="4.53.2") - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "GteNewModel": - vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - - mteb_test_embed_models(hf_runner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) @@ -91,12 +93,8 @@ def test_embed_models_correctness(hf_runner, vllm_runner, check_transformers_version(model_info.name, max_transformers_version="4.53.2") - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "GteNewModel": - vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts, vllm_extra_kwargs) + example_prompts) @pytest.mark.parametrize("model_info", RERANK_MODELS) diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling/test_mxbai_rerank.py index 480bd5e4567c..73823deeff4e 100644 --- a/tests/models/language/pooling/test_mxbai_rerank.py +++ b/tests/models/language/pooling/test_mxbai_rerank.py @@ -10,12 +10,20 @@ from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from .mteb_utils import mteb_test_rerank_models +mxbai_rerank_hf_overrides = { + "architectures": ["Qwen2ForSequenceClassification"], + "classifier_from_token": ["0", "1"], + "method": "from_2_way_softmax", +} + RERANK_MODELS = [ LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", architecture="Qwen2ForSequenceClassification", + hf_overrides=mxbai_rerank_hf_overrides, enable_test=True), LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", architecture="Qwen2ForSequenceClassification", + hf_overrides=mxbai_rerank_hf_overrides, enable_test=False) ] @@ -71,13 +79,4 @@ def compute_logits(inputs): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "Qwen2ForSequenceClassification": - vllm_extra_kwargs["hf_overrides"] = { - "architectures": ["Qwen2ForSequenceClassification"], - "classifier_from_token": ["0", "1"], - "method": "from_2_way_softmax", - } - - mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index 37f5566a330d..8c6537f3193f 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -11,12 +11,20 @@ from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from .mteb_utils import mteb_test_rerank_models +qwen3_reranker_hf_overrides = { + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"], + "is_original_qwen3_reranker": True, +} + RERANK_MODELS = [ LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B", architecture="Qwen3ForSequenceClassification", + hf_overrides=qwen3_reranker_hf_overrides, enable_test=True), LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B", architecture="Qwen3ForSequenceClassification", + hf_overrides=qwen3_reranker_hf_overrides, enable_test=False) ] @@ -74,18 +82,7 @@ def compute_logits(inputs): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - assert model_info.architecture == "Qwen3ForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - } - } - - mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", RERANK_MODELS) @@ -96,11 +93,6 @@ def test_rerank_models_mteb_tp(vllm_runner, assert model_info.architecture == "Qwen3ForSequenceClassification" vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - }, "tensor_parallel_size": 2, } diff --git a/tests/models/registry.py b/tests/models/registry.py index 2538e71692c4..85b4c96e3b1c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -365,6 +365,10 @@ def check_available_online( # [Cross-encoder] "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501 + "GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501 + trust_remote_code=True, + hf_overrides={ + "architectures": ["GteNewForSequenceClassification"]}),# noqa: E501 "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 diff --git a/tests/models/utils.py b/tests/models/utils.py index 84aeb927c5fa..0fb1f5b3753b 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -3,7 +3,8 @@ import warnings from collections.abc import Sequence -from typing import Any, NamedTuple, Optional, Union +from dataclasses import dataclass +from typing import Any, Optional, Union import torch import torch.nn.functional as F @@ -339,36 +340,43 @@ def softmax(data): return F.softmax(data, dim=-1) -class EmbedModelInfo(NamedTuple): +@dataclass +class ModelInfo: name: str - is_matryoshka: bool = False - matryoshka_dimensions: Optional[list[int]] = None architecture: str = "" dtype: str = "auto" + hf_overrides: Optional[dict[str, Any]] = None default_pooling_type: str = "" enable_test: bool = True +@dataclass +class EmbedModelInfo(ModelInfo): + is_matryoshka: bool = False + matryoshka_dimensions: Optional[list[int]] = None + + +@dataclass class CLSPoolingEmbedModelInfo(EmbedModelInfo): default_pooling_type: str = "CLS" +@dataclass class LASTPoolingEmbedModelInfo(EmbedModelInfo): default_pooling_type: str = "LAST" -class RerankModelInfo(NamedTuple): - name: str - architecture: str = "" - dtype: str = "auto" - default_pooling_type: str = "" - enable_test: bool = True +@dataclass +class RerankModelInfo(ModelInfo): + pass +@dataclass class CLSPoolingRerankModelInfo(RerankModelInfo): default_pooling_type: str = "CLS" +@dataclass class LASTPoolingRerankModelInfo(RerankModelInfo): default_pooling_type: str = "LAST" diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index dcb7e75456cd..3be7e11d947d 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -27,12 +27,15 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.utils import WeightsMapper +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + maybe_prefix) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from .interfaces import SupportsQuant +from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler +from .bert import BertPooler +from .interfaces import SupportsCrossEncoding, SupportsQuant from .interfaces_base import default_pooling_type @@ -406,9 +409,14 @@ def forward( class BertWithRope(nn.Module, SupportsQuant): hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + add_pooling_layer: bool = False): super().__init__() self.vllm_config = vllm_config + self.add_pooling_layer = add_pooling_layer self.config = vllm_config.model_config.hf_config self.embeddings = BertWithRopeEmbedding(self.config) self.encoder = BertWithRopeEncoder( @@ -416,6 +424,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): bias=getattr(self.config, "bias", True), rotary_kwargs=self.config.rotary_kwargs, prefix=f"{prefix}.encoder") + self.pooler = BertPooler(self.config) if add_pooling_layer else None def forward( self, @@ -448,7 +457,7 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if "pooler" in name: + if not self.add_pooling_layer and "pooler" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -508,8 +517,8 @@ class GteNewModel(BertWithRope): "attention.o_proj": "attn.out_proj", }) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs): + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) # GteNewModel only gate_up_proj does not have bias. # Hack method learned from vllm/model_executor/models/glm.py @@ -614,3 +623,65 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.jina_merge_lora_weights(weights) return super().load_weights(weights) + + +@default_pooling_type("CLS") +class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): + is_pooling_model = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.new = GteNewModel(vllm_config=vllm_config, + prefix=prefix, + add_pooling_layer=True) + self.classifier = RowParallelLinear(config.hidden_size, + config.num_labels, + input_is_parallel=False, + bias=True, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "classifier"), + return_bias=False) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + ClassifierPooler( + pooling=self.new.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config), + ), + "score": + ClassifierPooler( + pooling=self.new.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config), + ), + }) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loaded_params = loader.load_weights(weights) + return loaded_params + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + return self.new(input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index b0dbfacece3a..377b7bf26a07 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -406,6 +406,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, + "GteNewForSequenceClassification": GteNewModelConfig, "NomicBertModel": NomicBertModelConfig, "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig, "Qwen2ForRewardModel": Qwen2ForRewardModelConfig, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 12c0c77784db..9040189ee558 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -191,12 +191,14 @@ _CROSS_ENCODER_MODELS = { "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), + "GteNewForSequenceClassification": ("bert_with_rope", + "GteNewForSequenceClassification"), + "ModernBertForSequenceClassification": ("modernbert", + "ModernBertForSequenceClassification"), "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), "XLMRobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), - "ModernBertForSequenceClassification": ("modernbert", - "ModernBertForSequenceClassification"), # [Auto-converted (see adapters.py)] "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501, }