From 42fce02da95f9d258a3647f64f056edc899ea94a Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 25 Aug 2025 13:08:01 +0800 Subject: [PATCH 01/16] Score API Signed-off-by: wang.yuqi --- tests/conftest.py | 5 ++-- .../entrypoints/openai/test_classification.py | 30 +++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 6 +--- vllm/v1/worker/gpu_model_runner.py | 7 +++++ 4 files changed, 40 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 2bf88abb0f6c..03a0aeaca92c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -454,11 +454,10 @@ def classify(self, prompts: list[str]) -> list[str]: # output is final logits all_inputs = self.get_inputs(prompts) outputs = [] + problem_type = getattr(self.config, "problem_type", "") + for inputs in all_inputs: output = self.model(**self.wrap_device(inputs)) - - problem_type = getattr(self.config, "problem_type", "") - if problem_type == "regression": logits = output.logits[0].tolist() elif problem_type == "multi_label_classification": diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index 30078fe90257..36c96d76c2e5 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -226,3 +226,33 @@ def test_pooling(server: RemoteOpenAIServer, model_name: str): }, ) assert response.json()["error"]["type"] == "BadRequestError" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_score(server: RemoteOpenAIServer, model_name: str): + # score api is only enabled for num_labels == 1. + response = requests.post( + server.url_for("score"), + json={ + "model": model_name, + "text_1": "ping", + "text_2": "pong", + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_rerank(server: RemoteOpenAIServer, model_name: str): + # rerank api is only enabled for num_labels == 1. 
+ response = requests.post( + server.url_for("rerank"), + json={ + "model": model_name, + "query": "ping", + "documents": ["pong"], + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 14ba8aa64183..248500c2522e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1797,16 +1797,12 @@ async def init_app_state( state.openai_serving_models, request_logger=request_logger, ) if "classify" in supported_tasks else None - - enable_serving_reranking = ("classify" in supported_tasks and getattr( - model_config.hf_config, "num_labels", 0) == 1) state.openai_serving_scores = ServingScores( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, - ) if ("embed" in supported_tasks or enable_serving_reranking) else None - + ) if ("embed" in supported_tasks or "score" in supported_tasks) else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ec9887b8010a..58928cca312b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1302,6 +1302,13 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]: "Please turn off chunked prefill by " "`--no-enable-chunked-prefill` before using it.") + if "score" in supported_tasks: + num_labels = getattr( + self.model_config.hf_config, "num_labels", 0) + if num_labels != 1: + supported_tasks.remove("score") + logger.info_once("Score API is only enabled for num_labels == 1.") + return supported_tasks def get_supported_tasks(self) -> tuple[SupportedTask, ...]: From 4d2572e9cad394bc6636ca8dba017c4d333b5300 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 25 Aug 2025 14:02:04 +0800 Subject: [PATCH 02/16] + GteNewForSequenceClassification Signed-off-by: wang.yuqi --- docs/models/supported_models.md | 4 + tests/models/language/pooling/test_gte.py | 14 +++- tests/models/registry.py | 26 +++--- vllm/model_executor/models/bert_with_rope.py | 85 ++++++++++++++++++-- vllm/model_executor/models/config.py | 1 + vllm/model_executor/models/registry.py | 6 +- 6 files changed, 114 insertions(+), 22 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 8fb1019f2bdf..09d6f21ef186 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -497,6 +497,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A |--------------|--------|-------------------|----------------------|---------------------------|---------------------| | `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ | | `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | +| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ | | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. 
| | | ✅︎ | @@ -513,6 +514,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}' ``` +!!! note + The second-generation GTE model (mGTE-TRM) is named `GteNewForSequenceClassification`. The name `GteNewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture. + !!! note Load the official original `mxbai-rerank-v2` by using the following command. diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index f805a64103c0..d5b09e2c855f 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -60,11 +60,14 @@ ] RERANK_MODELS = [ - # classifier_pooling: mean CLSPoolingRerankModelInfo( + # classifier_pooling: mean "Alibaba-NLP/gte-reranker-modernbert-base", architecture="ModernBertForSequenceClassification", enable_test=True), + CLSPoolingRerankModelInfo("Alibaba-NLP/gte-multilingual-reranker-base", + architecture="GteNewForSequenceClassification", + enable_test=True), ] @@ -102,4 +105,11 @@ def test_embed_models_correctness(hf_runner, vllm_runner, @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(hf_runner, vllm_runner, model_info: RerankModelInfo) -> None: - mteb_test_rerank_models(hf_runner, vllm_runner, model_info) + vllm_extra_kwargs: dict[str, Any] = {} + if model_info.architecture == "GteNewForSequenceClassification": + vllm_extra_kwargs["hf_overrides"] = { + "architectures": ["GteNewForSequenceClassification"] + } + + mteb_test_rerank_models(hf_runner, vllm_runner, model_info, + vllm_extra_kwargs) diff --git a/tests/models/registry.py b/tests/models/registry.py index b34c6f2e5dc8..bd3020e6fe88 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -323,8 +323,8 @@ def check_available_online( _EMBEDDING_EXAMPLE_MODELS = { # [Text-only] - "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True), - "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501 + "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), + "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), # noqa: E501 "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", trust_remote_code=True), @@ -337,9 +337,9 @@ def check_available_online( "LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base", - trust_remote_code=True, v0_only=True), + trust_remote_code=True), "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe", - trust_remote_code=True, v0_only=True), # noqa: E501 + trust_remote_code=True), # noqa: E501 "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B", max_transformers_version="4.53", @@ -347,9 +347,9 @@ def check_available_online( "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B", max_transformers_version="4.53", transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501 - 
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501 - "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501 - "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501 + "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501 + "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501 + "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # noqa: E501 # [Multimodal] "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", @@ -364,16 +364,18 @@ def check_available_online( "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 # [Cross-encoder] - "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501 - "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 - "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 - "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 + "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 + "GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501 + trust_remote_code=True, + hf_overrides={"architectures": ["GteNewForSequenceClassification"]}), # noqa: E501 + "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501 + "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501 + "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501 } _AUTOMATIC_CONVERTED_MODELS = { # Use as_seq_cls_model for automatic conversion "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501 - v0_only=True, hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 "classifier_from_token": ["Yes"], # noqa: E501 "method": "no_post_processing"}), # noqa: E501 diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 129450927e56..bd2de885fb55 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -22,14 +22,19 @@ QKVParallelLinear, ReplicatedLinear, RowParallelLinear) +from vllm.model_executor.layers.pooler import (ClassifierPooler, + DispatchPooler, Pooler) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (SupportsQuant, +from vllm.model_executor.models.bert import BertPooler +from vllm.model_executor.models.interfaces import (SupportsCrossEncoding, + SupportsQuant, default_pooling_type) -from vllm.model_executor.models.utils import WeightsMapper +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + maybe_prefix) from vllm.model_executor.utils import 
set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -405,9 +410,14 @@ def forward( class BertWithRope(nn.Module, SupportsQuant): hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + add_pooling_layer: bool = False): super().__init__() self.vllm_config = vllm_config + self.add_pooling_layer = add_pooling_layer self.config = vllm_config.model_config.hf_config self.embeddings = BertWithRopeEmbedding(self.config) self.encoder = BertWithRopeEncoder( @@ -415,6 +425,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): bias=getattr(self.config, "bias", True), rotary_kwargs=self.config.rotary_kwargs, prefix=f"{prefix}.encoder") + self.pooler = BertPooler(self.config) if add_pooling_layer else None def forward( self, @@ -447,7 +458,7 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if "pooler" in name: + if not self.add_pooling_layer and "pooler" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -507,8 +518,8 @@ class GteNewModel(BertWithRope): "attention.o_proj": "attn.out_proj", }) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs): + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) # GteNewModel only gate_up_proj does not have bias. # Hack method learned from vllm/model_executor/models/glm.py @@ -613,3 +624,65 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.jina_merge_lora_weights(weights) return super().load_weights(weights) + + +@default_pooling_type("CLS") +class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): + is_pooling_model = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.new = GteNewModel(vllm_config=vllm_config, + prefix=prefix, + add_pooling_layer=True) + self.classifier = RowParallelLinear(config.hidden_size, + config.num_labels, + input_is_parallel=False, + bias=True, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "classifier"), + return_bias=False) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + ClassifierPooler( + pooling=self.new.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config), + ), + "score": + ClassifierPooler( + pooling=self.new.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config), + ), + }) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loaded_params = loader.load_weights(weights) + return loaded_params + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + return 
self.new(input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 882df7e8162c..f85c4c0958a0 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -365,6 +365,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, + "GteNewForSequenceClassification": GteNewModelConfig, "NomicBertModel": NomicBertModelConfig, "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig, "Qwen2ForRewardModel": Qwen2ForRewardModelConfig, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index ebf78771e40a..43b90f9a48ff 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -188,12 +188,14 @@ _CROSS_ENCODER_MODELS = { "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), + "GteNewForSequenceClassification": ("bert_with_rope", + "GteNewForSequenceClassification"), "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), - "XLMRobertaForSequenceClassification": ("roberta", - "RobertaForSequenceClassification"), "ModernBertForSequenceClassification": ("modernbert", "ModernBertForSequenceClassification"), + "XLMRobertaForSequenceClassification": ("roberta", + "RobertaForSequenceClassification"), # [Auto-converted (see adapters.py)] "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501, } From a8d845fe4662de478038dad7614ce8705149eb76 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 25 Aug 2025 16:15:02 +0800 Subject: [PATCH 03/16] baseline Signed-off-by: wang.yuqi --- vllm/model_executor/models/config.py | 2 ++ vllm/model_executor/models/qwen2_rm.py | 2 -- vllm/v1/worker/gpu_model_runner.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index f85c4c0958a0..58cc8ee8a827 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -168,6 +168,7 @@ class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: + vllm_config.model_config.hf_config.num_labels = 2 pooler_config = vllm_config.model_config.pooler_config if pooler_config.step_tag_id is None: @@ -178,6 +179,7 @@ class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: + vllm_config.model_config.hf_config.num_labels = 1 pooler_config = vllm_config.model_config.pooler_config if pooler_config.softmax is None: diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index e0a30e04c602..cb3006ff1750 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -93,7 +93,6 @@ def load_weights(self, weights: Iterable[tuple[str, class Qwen2ForRewardModel(Qwen2RewardBaseModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - vllm_config.model_config.hf_config.num_labels = 1 super().__init__(vllm_config=vllm_config, prefix=prefix) pooler_config = vllm_config.model_config.pooler_config @@ -107,7 +106,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): 
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - vllm_config.model_config.hf_config.num_labels = 2 super().__init__(vllm_config=vllm_config, prefix=prefix) pooler_config = vllm_config.model_config.pooler_config diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 58928cca312b..29c0b9643b5a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1303,11 +1303,11 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]: "`--no-enable-chunked-prefill` before using it.") if "score" in supported_tasks: - num_labels = getattr( - self.model_config.hf_config, "num_labels", 0) + num_labels = getattr(self.model_config.hf_config, "num_labels", 0) if num_labels != 1: supported_tasks.remove("score") - logger.info_once("Score API is only enabled for num_labels == 1.") + logger.info_once( + "Score API is only enabled for num_labels == 1.") return supported_tasks From 708962c3abf75164e25c34396a3e955ada9976b4 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 25 Aug 2025 16:47:09 +0800 Subject: [PATCH 04/16] Removing pooled_data.shape[-1] causes CUDA sync Signed-off-by: wang.yuqi --- vllm/model_executor/layers/pooler.py | 30 ++++++++++++++++++---------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index eebf7b2508db..7c93696cb9a8 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -389,9 +389,15 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: class PoolerClassify(PoolerActivation): + def __init__(self): + super().__init__() + + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + self.num_labels = getattr(vllm_config.model_config.hf_config, "num_labels", 0) + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - num_labels = pooled_data.shape[-1] - if num_labels < 2: + if self.num_labels == 1: return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype) @@ -399,9 +405,15 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: class PoolerScore(PoolerActivation): + def __init__(self): + super().__init__() + + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + self.num_labels = getattr(vllm_config.model_config.hf_config, "num_labels", 0) + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - num_labels = pooled_data.shape[-1] - if num_labels < 2: + if self.num_labels == 1: return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) return pooled_data @@ -661,13 +673,8 @@ def forward( # pooled_data shape: [batchsize, hidden_size] if self.classifier is not None: - # apply classifier once on the full batch if possible - if isinstance(pooled_data, torch.Tensor): - pooled_data = self.classifier(pooled_data) - elif len({data.shape for data in pooled_data}) <= 1: - pooled_data = self.classifier(torch.stack(pooled_data)) - else: - pooled_data = [self.classifier(data) for data in pooled_data] + pooled_data = self.classifier(pooled_data) + # pooled_data shape: [batchsize, num_labels] pooling_params = get_pooling_params(pooling_metadata) flags = [p.activation for p in pooling_params] @@ -680,6 +687,7 @@ def forward( for vecs, f in zip(pooled_data, flags) ] + # scores shape: [batchsize, num_labels] return 
build_output(scores) From 91190ad4d02ac095c2d9d73e058a095d30e71339 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 25 Aug 2025 17:44:50 +0800 Subject: [PATCH 05/16] hf_overrides Signed-off-by: wang.yuqi --- tests/models/language/pooling/embed_utils.py | 3 ++ tests/models/language/pooling/mteb_utils.py | 6 ++++ tests/models/language/pooling/test_gte.py | 34 +++++++------------- tests/models/utils.py | 18 +++++------ 4 files changed, 29 insertions(+), 32 deletions(-) diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py index 61c5fcab4f8a..a74ad2aa2597 100644 --- a/tests/models/language/pooling/embed_utils.py +++ b/tests/models/language/pooling/embed_utils.py @@ -51,6 +51,9 @@ def correctness_test_embed_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index 4a1f8a53d024..640858125bfc 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -172,6 +172,9 @@ def mteb_test_embed_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, @@ -284,6 +287,9 @@ def mteb_test_rerank_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index d5b09e2c855f..9911620c018e 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any import pytest @@ -33,12 +32,15 @@ ########### NewModel CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base", architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, enable_test=True), CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, enable_test=True), CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, enable_test=True), ########### Qwen2ForCausalLM LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", @@ -65,9 +67,11 @@ "Alibaba-NLP/gte-reranker-modernbert-base", architecture="ModernBertForSequenceClassification", enable_test=True), - CLSPoolingRerankModelInfo("Alibaba-NLP/gte-multilingual-reranker-base", - architecture="GteNewForSequenceClassification", - enable_test=True), + CLSPoolingRerankModelInfo( + "Alibaba-NLP/gte-multilingual-reranker-base", + architecture="GteNewForSequenceClassification", + hf_overrides={"architectures": ["GteNewForSequenceClassification"]}, + enable_test=True), ] @@ -78,12 +82,7 @@ def test_embed_models_mteb(hf_runner, vllm_runner, 
check_transformers_version(model_info.name, max_transformers_version="4.53.2") - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "GteNewModel": - vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - - mteb_test_embed_models(hf_runner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) @@ -94,22 +93,11 @@ def test_embed_models_correctness(hf_runner, vllm_runner, check_transformers_version(model_info.name, max_transformers_version="4.53.2") - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "GteNewModel": - vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts, vllm_extra_kwargs) + example_prompts) @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(hf_runner, vllm_runner, model_info: RerankModelInfo) -> None: - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "GteNewForSequenceClassification": - vllm_extra_kwargs["hf_overrides"] = { - "architectures": ["GteNewForSequenceClassification"] - } - - mteb_test_rerank_models(hf_runner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/utils.py b/tests/models/utils.py index 84aeb927c5fa..9330bacbceff 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -339,16 +339,20 @@ def softmax(data): return F.softmax(data, dim=-1) -class EmbedModelInfo(NamedTuple): +class ModelInfo(NamedTuple): name: str - is_matryoshka: bool = False - matryoshka_dimensions: Optional[list[int]] = None architecture: str = "" dtype: str = "auto" + hf_overrides: Optional[dict[str, Any]] = None default_pooling_type: str = "" enable_test: bool = True +class EmbedModelInfo(ModelInfo): + is_matryoshka: bool = False + matryoshka_dimensions: Optional[list[int]] = None + + class CLSPoolingEmbedModelInfo(EmbedModelInfo): default_pooling_type: str = "CLS" @@ -357,12 +361,8 @@ class LASTPoolingEmbedModelInfo(EmbedModelInfo): default_pooling_type: str = "LAST" -class RerankModelInfo(NamedTuple): - name: str - architecture: str = "" - dtype: str = "auto" - default_pooling_type: str = "" - enable_test: bool = True +class RerankModelInfo(ModelInfo): + pass class CLSPoolingRerankModelInfo(RerankModelInfo): From e2cd559a28a8944e2e5c96af16117e701a122ba9 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 25 Aug 2025 17:50:00 +0800 Subject: [PATCH 06/16] separate Signed-off-by: wang.yuqi --- .../entrypoints/openai/test_classification.py | 30 ------------------- vllm/entrypoints/openai/api_server.py | 6 +++- vllm/model_executor/layers/pooler.py | 30 +++++++------------ vllm/v1/worker/gpu_model_runner.py | 7 ----- 4 files changed, 16 insertions(+), 57 deletions(-) diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index 36c96d76c2e5..30078fe90257 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -226,33 +226,3 @@ def test_pooling(server: RemoteOpenAIServer, model_name: str): }, ) assert response.json()["error"]["type"] == "BadRequestError" - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_score(server: RemoteOpenAIServer, model_name: str): - # score api is only enabled for num_labels == 1. 
- response = requests.post( - server.url_for("score"), - json={ - "model": model_name, - "text_1": "ping", - "text_2": "pong", - }, - ) - assert response.json()["error"]["type"] == "BadRequestError" - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_rerank(server: RemoteOpenAIServer, model_name: str): - # rerank api is only enabled for num_labels == 1. - response = requests.post( - server.url_for("rerank"), - json={ - "model": model_name, - "query": "ping", - "documents": ["pong"], - }, - ) - assert response.json()["error"]["type"] == "BadRequestError" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 248500c2522e..14ba8aa64183 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1797,12 +1797,16 @@ async def init_app_state( state.openai_serving_models, request_logger=request_logger, ) if "classify" in supported_tasks else None + + enable_serving_reranking = ("classify" in supported_tasks and getattr( + model_config.hf_config, "num_labels", 0) == 1) state.openai_serving_scores = ServingScores( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, - ) if ("embed" in supported_tasks or "score" in supported_tasks) else None + ) if ("embed" in supported_tasks or enable_serving_reranking) else None + state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 7c93696cb9a8..eebf7b2508db 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -389,15 +389,9 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: class PoolerClassify(PoolerActivation): - def __init__(self): - super().__init__() - - from vllm.config import get_current_vllm_config - vllm_config = get_current_vllm_config() - self.num_labels = getattr(vllm_config.model_config.hf_config, "num_labels", 0) - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - if self.num_labels == 1: + num_labels = pooled_data.shape[-1] + if num_labels < 2: return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype) @@ -405,15 +399,9 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: class PoolerScore(PoolerActivation): - def __init__(self): - super().__init__() - - from vllm.config import get_current_vllm_config - vllm_config = get_current_vllm_config() - self.num_labels = getattr(vllm_config.model_config.hf_config, "num_labels", 0) - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - if self.num_labels == 1: + num_labels = pooled_data.shape[-1] + if num_labels < 2: return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) return pooled_data @@ -673,8 +661,13 @@ def forward( # pooled_data shape: [batchsize, hidden_size] if self.classifier is not None: - pooled_data = self.classifier(pooled_data) - # pooled_data shape: [batchsize, num_labels] + # apply classifier once on the full batch if possible + if isinstance(pooled_data, torch.Tensor): + pooled_data = self.classifier(pooled_data) + elif len({data.shape for data in pooled_data}) <= 1: + pooled_data = self.classifier(torch.stack(pooled_data)) + else: + pooled_data = [self.classifier(data) for data in pooled_data] pooling_params = get_pooling_params(pooling_metadata) flags = [p.activation for p in pooling_params] @@ -687,7 +680,6 @@ 
def forward( for vecs, f in zip(pooled_data, flags) ] - # scores shape: [batchsize, num_labels] return build_output(scores) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 08e76ddd9189..73117c75b9af 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1297,13 +1297,6 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]: "Please turn off chunked prefill by " "`--no-enable-chunked-prefill` before using it.") - if "score" in supported_tasks: - num_labels = getattr(self.model_config.hf_config, "num_labels", 0) - if num_labels != 1: - supported_tasks.remove("score") - logger.info_once( - "Score API is only enabled for num_labels == 1.") - return supported_tasks def get_supported_tasks(self) -> tuple[SupportedTask, ...]: From 7d84d64ef50cf83ac226bd1de58a75064b2a7671 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 25 Aug 2025 17:59:16 +0800 Subject: [PATCH 07/16] more hf_overrides Signed-off-by: wang.yuqi --- .../pooling/test_bge_reranker_v2_gemma.py | 20 +++++------ .../language/pooling/test_mxbai_rerank.py | 19 +++++------ .../language/pooling/test_qwen3_reranker.py | 33 +++++-------------- 3 files changed, 26 insertions(+), 46 deletions(-) diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py index 206524d7caad..4dc9601ecb7a 100644 --- a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py @@ -13,7 +13,14 @@ RERANK_MODELS = [ LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", - architecture="GemmaForSequenceClassification"), + architecture="GemmaForSequenceClassification", + hf_overrides={ + "architectures": + ["GemmaForSequenceClassification"], + "classifier_from_token": ["Yes"], + "method": + "no_post_processing", + }), ] PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." 
# noqa: E501 @@ -123,18 +130,7 @@ def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo, monkeypatch) -> None: monkeypatch.setenv("VLLM_USE_V1", "0") - assert model_info.architecture == "GemmaForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["GemmaForSequenceClassification"], - "classifier_from_token": ["Yes"], - "method": "no_post_processing", - } - } - mteb_test_rerank_models(GemmaRerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs, vllm_mteb_encoder=GemmaMtebEncoder) diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling/test_mxbai_rerank.py index 480bd5e4567c..73823deeff4e 100644 --- a/tests/models/language/pooling/test_mxbai_rerank.py +++ b/tests/models/language/pooling/test_mxbai_rerank.py @@ -10,12 +10,20 @@ from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from .mteb_utils import mteb_test_rerank_models +mxbai_rerank_hf_overrides = { + "architectures": ["Qwen2ForSequenceClassification"], + "classifier_from_token": ["0", "1"], + "method": "from_2_way_softmax", +} + RERANK_MODELS = [ LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", architecture="Qwen2ForSequenceClassification", + hf_overrides=mxbai_rerank_hf_overrides, enable_test=True), LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", architecture="Qwen2ForSequenceClassification", + hf_overrides=mxbai_rerank_hf_overrides, enable_test=False) ] @@ -71,13 +79,4 @@ def compute_logits(inputs): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "Qwen2ForSequenceClassification": - vllm_extra_kwargs["hf_overrides"] = { - "architectures": ["Qwen2ForSequenceClassification"], - "classifier_from_token": ["0", "1"], - "method": "from_2_way_softmax", - } - - mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index 37f5566a330d..86dbe6de011b 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -11,12 +11,20 @@ from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from .mteb_utils import mteb_test_rerank_models +qwen3_reranker_hf_overrides = { + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"], + "is_original_qwen3_reranker": True, +} + RERANK_MODELS = [ LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B", architecture="Qwen3ForSequenceClassification", + hf_overrides=qwen3_reranker_hf_overrides, enable_test=True), LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B", architecture="Qwen3ForSequenceClassification", + hf_overrides=qwen3_reranker_hf_overrides, enable_test=False) ] @@ -74,18 +82,7 @@ def compute_logits(inputs): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - assert model_info.architecture == "Qwen3ForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - } - } - - 
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", RERANK_MODELS) @@ -93,19 +90,7 @@ def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: def test_rerank_models_mteb_tp(vllm_runner, model_info: RerankModelInfo) -> None: - assert model_info.architecture == "Qwen3ForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - }, - "tensor_parallel_size": 2, - } - mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs, atol=1.2e-2) From e22a3c427707a36c9cc2a45861f5732dc79c5afb Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 25 Aug 2025 18:16:52 +0800 Subject: [PATCH 08/16] dataclass using ModelInfo Signed-off-by: wang.yuqi --- tests/models/utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/models/utils.py b/tests/models/utils.py index 9330bacbceff..0fb1f5b3753b 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -3,7 +3,8 @@ import warnings from collections.abc import Sequence -from typing import Any, NamedTuple, Optional, Union +from dataclasses import dataclass +from typing import Any, Optional, Union import torch import torch.nn.functional as F @@ -339,7 +340,8 @@ def softmax(data): return F.softmax(data, dim=-1) -class ModelInfo(NamedTuple): +@dataclass +class ModelInfo: name: str architecture: str = "" dtype: str = "auto" @@ -348,27 +350,33 @@ class ModelInfo(NamedTuple): enable_test: bool = True +@dataclass class EmbedModelInfo(ModelInfo): is_matryoshka: bool = False matryoshka_dimensions: Optional[list[int]] = None +@dataclass class CLSPoolingEmbedModelInfo(EmbedModelInfo): default_pooling_type: str = "CLS" +@dataclass class LASTPoolingEmbedModelInfo(EmbedModelInfo): default_pooling_type: str = "LAST" +@dataclass class RerankModelInfo(ModelInfo): pass +@dataclass class CLSPoolingRerankModelInfo(RerankModelInfo): default_pooling_type: str = "CLS" +@dataclass class LASTPoolingRerankModelInfo(RerankModelInfo): default_pooling_type: str = "LAST" From 400de6cf4230323bf1853589208c7acdaab84f34 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 25 Aug 2025 18:47:26 +0800 Subject: [PATCH 09/16] - v0_only Signed-off-by: wang.yuqi --- tests/models/registry.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index bd3020e6fe88..3ae7884ab9b8 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -323,8 +323,8 @@ def check_available_online( _EMBEDDING_EXAMPLE_MODELS = { # [Text-only] - "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), - "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), # noqa: E501 + "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True), + "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501 "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", trust_remote_code=True), @@ -337,9 +337,9 @@ def check_available_online( "LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), "ModernBertModel": 
_HfExamplesInfo("Alibaba-NLP/gte-modernbert-base", - trust_remote_code=True), + trust_remote_code=True, v0_only=True), "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe", - trust_remote_code=True), # noqa: E501 + trust_remote_code=True, v0_only=True), # noqa: E501 "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B", max_transformers_version="4.53", @@ -347,9 +347,9 @@ def check_available_online( "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B", max_transformers_version="4.53", transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501 - "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501 - "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501 - "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # noqa: E501 + "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501 + "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501 + "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501 # [Multimodal] "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", @@ -364,18 +364,20 @@ def check_available_online( "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 # [Cross-encoder] - "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 - "GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501 - trust_remote_code=True, - hf_overrides={"architectures": ["GteNewForSequenceClassification"]}), # noqa: E501 - "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501 - "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501 - "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501 + "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501 + "GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501 + trust_remote_code=True, + hf_overrides={ + "architectures": ["GteNewForSequenceClassification"]}),# noqa: E501 + "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 + "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 + "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 } _AUTOMATIC_CONVERTED_MODELS = { # Use as_seq_cls_model for automatic conversion "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501 + v0_only=True, hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 "classifier_from_token": ["Yes"], # noqa: E501 "method": "no_post_processing"}), # noqa: E501 From e19198ab6c5692f1a5086ce23ce52260d7de4a5c Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 25 Aug 2025 22:34:02 +0800 Subject: [PATCH 10/16] - config.py Signed-off-by: wang.yuqi --- 
vllm/model_executor/models/config.py | 2 -- vllm/model_executor/models/qwen2_rm.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 58cc8ee8a827..f85c4c0958a0 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -168,7 +168,6 @@ class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: - vllm_config.model_config.hf_config.num_labels = 2 pooler_config = vllm_config.model_config.pooler_config if pooler_config.step_tag_id is None: @@ -179,7 +178,6 @@ class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: - vllm_config.model_config.hf_config.num_labels = 1 pooler_config = vllm_config.model_config.pooler_config if pooler_config.softmax is None: diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index cb3006ff1750..e0a30e04c602 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -93,6 +93,7 @@ def load_weights(self, weights: Iterable[tuple[str, class Qwen2ForRewardModel(Qwen2RewardBaseModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + vllm_config.model_config.hf_config.num_labels = 1 super().__init__(vllm_config=vllm_config, prefix=prefix) pooler_config = vllm_config.model_config.pooler_config @@ -106,6 +107,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + vllm_config.model_config.hf_config.num_labels = 2 super().__init__(vllm_config=vllm_config, prefix=prefix) pooler_config = vllm_config.model_config.pooler_config From ffea6007d442e43ab72701f7ea7bc68abca034f6 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 25 Aug 2025 22:37:39 +0800 Subject: [PATCH 11/16] - monkeypatch Signed-off-by: wang.yuqi --- tests/models/language/pooling/test_bge_reranker_v2_gemma.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py index 4dc9601ecb7a..f473e0ba01ff 100644 --- a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py @@ -126,9 +126,7 @@ def predict( @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo, - monkeypatch) -> None: - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: mteb_test_rerank_models(GemmaRerankerHfRunner, vllm_runner, From e4ea8613e62c955ae1f812712c4cb47e6f55f7b8 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 25 Aug 2025 22:40:48 +0800 Subject: [PATCH 12/16] + tensor_parallel_size Signed-off-by: wang.yuqi --- tests/models/language/pooling/test_qwen3_reranker.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index 86dbe6de011b..8c6537f3193f 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -90,7 +90,14 @@ def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> 
None: def test_rerank_models_mteb_tp(vllm_runner, model_info: RerankModelInfo) -> None: + assert model_info.architecture == "Qwen3ForSequenceClassification" + + vllm_extra_kwargs: dict[str, Any] = { + "tensor_parallel_size": 2, + } + mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, + vllm_extra_kwargs, atol=1.2e-2) From 7c9c0f6c7cbef5898e09fe3378754fdbe66beb99 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 25 Aug 2025 22:46:12 +0800 Subject: [PATCH 13/16] fix registry Signed-off-by: wang.yuqi --- vllm/model_executor/models/registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 43b90f9a48ff..aacd5386e8d6 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -190,10 +190,10 @@ "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), "GteNewForSequenceClassification": ("bert_with_rope", "GteNewForSequenceClassification"), - "RobertaForSequenceClassification": ("roberta", - "RobertaForSequenceClassification"), "ModernBertForSequenceClassification": ("modernbert", "ModernBertForSequenceClassification"), + "RobertaForSequenceClassification": ("roberta", + "RobertaForSequenceClassification"), "XLMRobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), # [Auto-converted (see adapters.py)] From bcbf00e00395e382ed49905d3d7e8614144025bd Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 28 Aug 2025 12:10:54 +0800 Subject: [PATCH 14/16] conflicts Signed-off-by: wang.yuqi --- vllm/model_executor/models/bert_with_rope.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index bd2de885fb55..2a2a8ac7c50b 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -22,19 +22,14 @@ QKVParallelLinear, ReplicatedLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import (ClassifierPooler, - DispatchPooler, Pooler) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.bert import BertPooler -from vllm.model_executor.models.interfaces import (SupportsCrossEncoding, - SupportsQuant, +from vllm.model_executor.models.interfaces import (SupportsQuant, default_pooling_type) -from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, - maybe_prefix) +from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors From 112f557e2ac97a2cc54e03afc5d82fb635fc9ad7 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 28 Aug 2025 12:15:38 +0800 Subject: [PATCH 15/16] add back Signed-off-by: wang.yuqi --- vllm/model_executor/models/bert_with_rope.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index e7160bf12b8b..3be7e11d947d 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -27,12 +27,15 @@ 
from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.utils import WeightsMapper
+from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
+                                              maybe_prefix)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsQuant
+from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler
+from .bert import BertPooler
+from .interfaces import SupportsCrossEncoding, SupportsQuant
 from .interfaces_base import default_pooling_type
 

From 3ba9300be261c934819975057d595e119906381b Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Thu, 28 Aug 2025 13:08:16 +0800
Subject: [PATCH 16/16] fix docs

Signed-off-by: wang.yuqi

---
 docs/models/supported_models.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index d4e4d07767ca..34e465584888 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -515,7 +515,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
 ```
 
 !!! note
-    The second-generation GTE model (mGTE-TRM) is named `GteNewForSequenceClassification`. The name `GteNewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
+    The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. Because that name is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
 
 !!! note
     Load the official original `mxbai-rerank-v2` by using the following command.
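
For reference, a minimal sketch of serving and querying the reranker added in this
series, combining the registry entry from PATCH 02 with the `--hf-overrides`
requirement documented in PATCH 16. The request shape mirrors the rerank test from
PATCH 01; the base URL, port, and route are assumptions (the tests resolve them via
`RemoteOpenAIServer.url_for()`) rather than anything this series pins down.

# Assumed serve command, combining the PATCH 16 docs note with the
# trust_remote_code flag from the PATCH 02 registry entry:
#
#   vllm serve Alibaba-NLP/gte-multilingual-reranker-base \
#       --trust-remote-code \
#       --hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'

import requests

BASE_URL = "http://localhost:8000"  # assumption: default local vLLM server

# Rerank one document against a query; same payload shape as test_rerank above.
# With num_labels == 1 this model qualifies for the score/rerank APIs, so a
# normal response (rather than BadRequestError) is expected here.
response = requests.post(
    f"{BASE_URL}/rerank",
    json={
        "model": "Alibaba-NLP/gte-multilingual-reranker-base",
        "query": "ping",
        "documents": ["pong"],
    },
)
response.raise_for_status()
print(response.json())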