4 changes: 4 additions & 0 deletions docs/models/supported_models.md
@@ -497,6 +497,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ |
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ |
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ |
@@ -513,6 +514,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'
```

!!! note
    The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. Since this name is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
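
    For example, a minimal offline scoring sketch (the `LLM` keyword arguments used here, such as `hf_overrides` and `trust_remote_code`, follow the usage elsewhere on this page and in this PR's tests; the query and passage strings are illustrative only):

    ```python
    from vllm import LLM

    # Force the GteNewForSequenceClassification implementation for the
    # generically named "NewForSequenceClassification" architecture.
    llm = LLM(
        model="Alibaba-NLP/gte-multilingual-reranker-base",
        trust_remote_code=True,
        hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
    )

    # Score a query against a candidate passage (illustrative inputs).
    outputs = llm.score("What is the capital of France?",
                        ["Paris is the capital of France."])
    print(outputs[0].outputs.score)
    ```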

!!! note
Load the official original `mxbai-rerank-v2` by using the following command.

5 changes: 2 additions & 3 deletions tests/conftest.py
@@ -454,11 +454,10 @@ def classify(self, prompts: list[str]) -> list[str]:
# output is final logits
all_inputs = self.get_inputs(prompts)
outputs = []
problem_type = getattr(self.config, "problem_type", "")

for inputs in all_inputs:
output = self.model(**self.wrap_device(inputs))

problem_type = getattr(self.config, "problem_type", "")

if problem_type == "regression":
logits = output.logits[0].tolist()
elif problem_type == "multi_label_classification":
3 changes: 3 additions & 0 deletions tests/models/language/pooling/embed_utils.py
@@ -51,6 +51,9 @@ def correctness_test_embed_models(hf_runner,
vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype

if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides

with vllm_runner(model_info.name,
runner="pooling",
max_model_len=None,
6 changes: 6 additions & 0 deletions tests/models/language/pooling/mteb_utils.py
@@ -172,6 +172,9 @@ def mteb_test_embed_models(hf_runner,
vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype

if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides

with vllm_runner(model_info.name,
runner="pooling",
max_model_len=None,
@@ -284,6 +287,9 @@ def mteb_test_rerank_models(hf_runner,
vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype

if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides

with vllm_runner(model_info.name,
runner="pooling",
max_model_len=None,
24 changes: 9 additions & 15 deletions tests/models/language/pooling/test_bge_reranker_v2_gemma.py
@@ -13,7 +13,14 @@

RERANK_MODELS = [
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
architecture="GemmaForSequenceClassification"),
architecture="GemmaForSequenceClassification",
hf_overrides={
"architectures":
["GemmaForSequenceClassification"],
"classifier_from_token": ["Yes"],
"method":
"no_post_processing",
}),
]

PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
@@ -119,22 +126,9 @@ def predict(


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo,
monkeypatch) -> None:
monkeypatch.setenv("VLLM_USE_V1", "0")

assert model_info.architecture == "GemmaForSequenceClassification"

vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["GemmaForSequenceClassification"],
"classifier_from_token": ["Yes"],
"method": "no_post_processing",
}
}
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:

mteb_test_rerank_models(GemmaRerankerHfRunner,
vllm_runner,
model_info,
vllm_extra_kwargs,
vllm_mteb_encoder=GemmaMtebEncoder)
24 changes: 11 additions & 13 deletions tests/models/language/pooling/test_gte.py
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import pytest

@@ -33,12 +32,15 @@
########### NewModel
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True),
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True),
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True),
########### Qwen2ForCausalLM
LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
@@ -60,11 +62,16 @@
]

RERANK_MODELS = [
# classifier_pooling: mean
CLSPoolingRerankModelInfo(
# classifier_pooling: mean
"Alibaba-NLP/gte-reranker-modernbert-base",
architecture="ModernBertForSequenceClassification",
enable_test=True),
CLSPoolingRerankModelInfo(
"Alibaba-NLP/gte-multilingual-reranker-base",
architecture="GteNewForSequenceClassification",
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
enable_test=True),
]


@@ -75,12 +82,7 @@ def test_embed_models_mteb(hf_runner, vllm_runner,
check_transformers_version(model_info.name,
max_transformers_version="4.53.2")

vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}

mteb_test_embed_models(hf_runner, vllm_runner, model_info,
vllm_extra_kwargs)
mteb_test_embed_models(hf_runner, vllm_runner, model_info)


@pytest.mark.parametrize("model_info", MODELS)
@@ -91,12 +93,8 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
check_transformers_version(model_info.name,
max_transformers_version="4.53.2")

vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}

correctness_test_embed_models(hf_runner, vllm_runner, model_info,
example_prompts, vllm_extra_kwargs)
example_prompts)


@pytest.mark.parametrize("model_info", RERANK_MODELS)
19 changes: 9 additions & 10 deletions tests/models/language/pooling/test_mxbai_rerank.py
@@ -10,12 +10,20 @@
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
from .mteb_utils import mteb_test_rerank_models

mxbai_rerank_hf_overrides = {
"architectures": ["Qwen2ForSequenceClassification"],
"classifier_from_token": ["0", "1"],
"method": "from_2_way_softmax",
}

RERANK_MODELS = [
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides,
enable_test=True),
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides,
enable_test=False)
]

@@ -71,13 +79,4 @@ def compute_logits(inputs):

@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "Qwen2ForSequenceClassification":
vllm_extra_kwargs["hf_overrides"] = {
"architectures": ["Qwen2ForSequenceClassification"],
"classifier_from_token": ["0", "1"],
"method": "from_2_way_softmax",
}

mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info,
vllm_extra_kwargs)
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info)
26 changes: 9 additions & 17 deletions tests/models/language/pooling/test_qwen3_reranker.py
@@ -11,12 +11,20 @@
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
from .mteb_utils import mteb_test_rerank_models

qwen3_reranker_hf_overrides = {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
}

RERANK_MODELS = [
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
architecture="Qwen3ForSequenceClassification",
hf_overrides=qwen3_reranker_hf_overrides,
enable_test=True),
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B",
architecture="Qwen3ForSequenceClassification",
hf_overrides=qwen3_reranker_hf_overrides,
enable_test=False)
]

@@ -74,18 +82,7 @@ def compute_logits(inputs):
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:

assert model_info.architecture == "Qwen3ForSequenceClassification"

vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
}
}

mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info,
vllm_extra_kwargs)
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info)


@pytest.mark.parametrize("model_info", RERANK_MODELS)
Expand All @@ -96,11 +93,6 @@ def test_rerank_models_mteb_tp(vllm_runner,
assert model_info.architecture == "Qwen3ForSequenceClassification"

vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
},
"tensor_parallel_size": 2,
}

4 changes: 4 additions & 0 deletions tests/models/registry.py
@@ -365,6 +365,10 @@ def check_available_online(

# [Cross-encoder]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
trust_remote_code=True,
hf_overrides={
"architectures": ["GteNewForSequenceClassification"]}),# noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
28 changes: 18 additions & 10 deletions tests/models/utils.py
@@ -3,7 +3,8 @@

import warnings
from collections.abc import Sequence
from typing import Any, NamedTuple, Optional, Union
from dataclasses import dataclass
from typing import Any, Optional, Union

import torch
import torch.nn.functional as F
@@ -339,36 +340,43 @@ def softmax(data):
return F.softmax(data, dim=-1)


class EmbedModelInfo(NamedTuple):
@dataclass
class ModelInfo:
name: str
is_matryoshka: bool = False
matryoshka_dimensions: Optional[list[int]] = None
architecture: str = ""
dtype: str = "auto"
hf_overrides: Optional[dict[str, Any]] = None
default_pooling_type: str = ""
enable_test: bool = True


@dataclass
class EmbedModelInfo(ModelInfo):
is_matryoshka: bool = False
matryoshka_dimensions: Optional[list[int]] = None


@dataclass
class CLSPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "CLS"


@dataclass
class LASTPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "LAST"


class RerankModelInfo(NamedTuple):
name: str
architecture: str = ""
dtype: str = "auto"
default_pooling_type: str = ""
enable_test: bool = True
@dataclass
class RerankModelInfo(ModelInfo):
pass


@dataclass
class CLSPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "CLS"


@dataclass
class LASTPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "LAST"

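
Taken together, these changes move per-model `hf_overrides` onto the shared `ModelInfo` dataclass so that the MTEB and correctness helpers can forward the overrides to `vllm_runner` without per-test special-casing. A minimal self-contained sketch of the resulting pattern (simplified from `tests/models/utils.py` and `tests/models/language/pooling/mteb_utils.py` in this PR; not the full classes):

```python
from dataclasses import dataclass
from typing import Any, Optional


# Simplified version of the ModelInfo dataclass introduced in this PR:
# hf_overrides now lives on the model description itself.
@dataclass
class ModelInfo:
    name: str
    architecture: str = ""
    dtype: str = "auto"
    hf_overrides: Optional[dict[str, Any]] = None
    default_pooling_type: str = ""
    enable_test: bool = True


model_info = ModelInfo(
    name="Alibaba-NLP/gte-multilingual-reranker-base",
    architecture="GteNewForSequenceClassification",
    hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
)

# The test helpers then build vllm_runner kwargs from the model info,
# forwarding hf_overrides only when it is set.
vllm_extra_kwargs: dict[str, Any] = {}
vllm_extra_kwargs["dtype"] = model_info.dtype
if model_info.hf_overrides is not None:
    vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides

print(vllm_extra_kwargs)
```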