From a7fe14f53ff18259cbcd0e971b4917d28b5c1669 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Wed, 30 Jul 2025 19:48:02 +0800
Subject: [PATCH 1/8] v1 support

Signed-off-by: wang.yuqi
---
 .../language/pooling/test_classification.py   |  8 --------
 tests/models/language/pooling/test_gte.py     | 16 ++++------------
 tests/models/language/pooling/test_jina.py    |  8 --------
 .../language/pooling/test_qwen3_reranker.py   |  6 ------
 vllm/config.py                                |  8 ++++++++
 vllm/model_executor/models/bert.py            |  2 --
 vllm/model_executor/models/config.py          |  2 +-
 vllm/model_executor/models/modernbert.py      |  2 --
 8 files changed, 13 insertions(+), 39 deletions(-)

diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py
index 77df6d16a367..c71fa9627533 100644
--- a/tests/models/language/pooling/test_classification.py
+++ b/tests/models/language/pooling/test_classification.py
@@ -6,14 +6,6 @@
 
 from vllm.platforms import current_platform
 
-# TODO: enable when float32 is supported by V1
-# @pytest.fixture(autouse=True)
-# def v1(run_with_both_engines):
-#     # Simple autouse wrapper to run both engines for each test
-#     # This can be promoted up to conftest.py to run for every
-#     # test in a package
-#     pass
-
 
 @pytest.mark.parametrize(
     "model",
diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py
index 0ad54785308e..6a3a0f150b6d 100644
--- a/tests/models/language/pooling/test_gte.py
+++ b/tests/models/language/pooling/test_gte.py
@@ -56,16 +56,10 @@
                    enable_test=False),
 ]
 
-V1FlashAttentionImpNotSupported = [
-    "Alibaba-NLP/gte-Qwen2-1.5B-instruct", "Alibaba-NLP/gte-modernbert-base"
-]
-
 
 @pytest.mark.parametrize("model_info", MODELS)
-def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo,
-                           monkeypatch) -> None:
-    if model_info.name in V1FlashAttentionImpNotSupported:
-        monkeypatch.setenv("VLLM_USE_V1", "0")
+def test_embed_models_mteb(hf_runner, vllm_runner,
+                           model_info: EmbedModelInfo) -> None:
 
     vllm_extra_kwargs: dict[str, Any] = {}
     if model_info.architecture == "GteNewModel":
@@ -77,10 +71,8 @@ def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo,
 
 @pytest.mark.parametrize("model_info", MODELS)
 def test_embed_models_correctness(hf_runner, vllm_runner,
-                                  model_info: EmbedModelInfo, example_prompts,
-                                  monkeypatch) -> None:
-    if model_info.name in V1FlashAttentionImpNotSupported:
-        monkeypatch.setenv("VLLM_USE_V1", "0")
+                                  model_info: EmbedModelInfo,
+                                  example_prompts) -> None:
 
     vllm_extra_kwargs: dict[str, Any] = {}
     if model_info.architecture == "GteNewModel":
diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py
index 2ae431de1683..c982952520d8 100644
--- a/tests/models/language/pooling/test_jina.py
+++ b/tests/models/language/pooling/test_jina.py
@@ -24,14 +24,6 @@
 ]
 
 
-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    # Simple autouse wrapper to run both engines for each test
-    # This can be promoted up to conftest.py to run for every
-    # test in a package
-    pass
-
-
 @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
 def test_embed_models_mteb(hf_runner, vllm_runner,
                            model_info: EmbedModelInfo) -> None:
diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py
index 9c6a833b4138..68e96f32700c 100644
--- a/tests/models/language/pooling/test_qwen3_reranker.py
+++ b/tests/models/language/pooling/test_qwen3_reranker.py
@@ -83,9 +83,6 @@ def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
         }
     }
 
-    if model_info.name == "Qwen/Qwen3-Reranker-4B":
-        vllm_extra_kwargs["max_num_seqs"] = 1
-
     mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info,
                             vllm_extra_kwargs)
 
@@ -106,9 +103,6 @@ def test_rerank_models_mteb_tp(vllm_runner,
         "tensor_parallel_size": 2,
     }
 
-    if model_info.name == "Qwen/Qwen3-Reranker-4B":
-        vllm_extra_kwargs["max_num_seqs"] = 1
-
     mteb_test_rerank_models(Qwen3RerankerHfRunner,
                             vllm_runner,
                             model_info,
diff --git a/vllm/config.py b/vllm/config.py
index 8e8c1198833c..5ca66e80dc43 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -776,6 +776,9 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
             raise ValueError(
                 "`override_neuron_config` is only supported on Neuron.")
 
+        # Avoid running try_verify_and_update_config multiple times
+        self.config_updated = False
+
         self._verify_quantization()
         self._verify_cuda_graph()
         self._verify_bnb_config()
@@ -4912,6 +4915,11 @@ def try_verify_and_update_config(self):
         if self.model_config is None:
             return
 
+        # Avoid running try_verify_and_update_config multiple times
+        if getattr(self.model_config, "config_updated", False):
+            return
+        self.model_config.config_updated = True
+
         architecture = self.model_config.architecture
         if architecture is None:
             return
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 504621c8abd8..22d5cc3a6bac 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -9,7 +9,6 @@
 from transformers import BertConfig
 
 from vllm.attention import Attention, AttentionType
-from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -334,7 +333,6 @@ def forward(self, hidden_states: torch.Tensor,
         return hidden_states
 
 
-@support_torch_compile
 class BertModel(nn.Module, SupportsQuant):
 
     is_pooling_model = True
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 6f50b1753098..9030ff307bee 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -93,7 +93,7 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         config.num_hidden_layers = config.n_layer
 
         head_dim = config.hidden_size // config.num_attention_heads
-        rotary_emb_dim = head_dim * config.rotary_emb_fraction
+        rotary_emb_dim = int(head_dim * config.rotary_emb_fraction)
         max_trained_positions = getattr(config, "max_trained_positions", 2048)
         config.rotary_kwargs = {
             "head_size": head_dim,
diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py
index fc2b0c1f5182..4967032a244e 100644
--- a/vllm/model_executor/models/modernbert.py
+++ b/vllm/model_executor/models/modernbert.py
@@ -8,7 +8,6 @@
 from transformers import ModernBertConfig
 
 from vllm.attention import Attention, AttentionType
-from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -200,7 +199,6 @@ def forward(
         return hidden_states
 
 
-@support_torch_compile
 class ModernBertModel(nn.Module):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={"layers.": "encoder_layer.layers."})

From 5bae8c44716200060e575b6f0a199e2c8d91e963 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Wed, 30 Jul 2025 19:57:57 +0800
Subject: [PATCH 2/8] fix

Signed-off-by: wang.yuqi
---
 vllm/model_executor/models/bert_with_rope.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py
index 5249acbd84a5..3ee0e664e838 100644
--- a/vllm/model_executor/models/bert_with_rope.py
+++ b/vllm/model_executor/models/bert_with_rope.py
@@ -8,7 +8,6 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionType
-from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -360,7 +359,6 @@ def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
         return hidden_states
 
 
-@support_torch_compile
 class BertWithRopeEncoder(nn.Module):
 
     def __init__(self,

From fa6f053070024f64e113a7c7a9dfc3ac22e19e4a Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Wed, 30 Jul 2025 20:05:07 +0800
Subject: [PATCH 3/8] fix

Signed-off-by: wang.yuqi
---
 tests/models/language/pooling/test_jina.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py
index c982952520d8..79538962882b 100644
--- a/tests/models/language/pooling/test_jina.py
+++ b/tests/models/language/pooling/test_jina.py
@@ -4,7 +4,6 @@
 
 import pytest
 
-import vllm.envs as envs
 from vllm import PoolingParams
 
 from ...utils import EmbedModelInfo, RerankModelInfo
@@ -55,9 +54,6 @@ def hf_model_callback(model):
 @pytest.mark.parametrize("model_info", RERANK_MODELS)
 def test_rerank_models_mteb(hf_runner, vllm_runner,
                             model_info: RerankModelInfo) -> None:
-    if (model_info.architecture == "XLMRobertaForSequenceClassification"
-            and envs.VLLM_USE_V1):
-        pytest.skip("Not supported yet")
 
     mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
 

From e8ce8daeeebfcfab034fe7c2e7a41925414c0abc Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Thu, 31 Jul 2025 13:27:40 +0800
Subject: [PATCH 4/8] fix

Signed-off-by: wang.yuqi
---
 tests/models/language/pooling/test_jina.py | 1 -
 vllm/model_executor/models/bert.py         | 2 ++
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py
index 79538962882b..59b634428cef 100644
--- a/tests/models/language/pooling/test_jina.py
+++ b/tests/models/language/pooling/test_jina.py
@@ -54,7 +54,6 @@ def hf_model_callback(model):
 @pytest.mark.parametrize("model_info", RERANK_MODELS)
 def test_rerank_models_mteb(hf_runner, vllm_runner,
                             model_info: RerankModelInfo) -> None:
-
     mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
 
 
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 22d5cc3a6bac..54fa5588bece 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -29,6 +29,7 @@
 
 from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
+from ...compilation.decorators import support_torch_compile
 
 
 class BertEmbedding(nn.Module):
@@ -334,6 +334,7 @@ def forward(self, hidden_states: torch.Tensor,
         return hidden_states
 
 
+@support_torch_compile
 class BertModel(nn.Module, SupportsQuant):
 
     is_pooling_model = True

From c1e394d199cbecd45cc8a48b59dc72303af2a3f2 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Thu, 31 Jul 2025 13:29:23 +0800
Subject: [PATCH 5/8] fix

Signed-off-by: wang.yuqi
---
 tests/models/language/pooling/test_gte.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py
index 6a3a0f150b6d..6d2eff709961 100644
--- a/tests/models/language/pooling/test_gte.py
+++ b/tests/models/language/pooling/test_gte.py
@@ -60,7 +60,6 @@
 @pytest.mark.parametrize("model_info", MODELS)
 def test_embed_models_mteb(hf_runner, vllm_runner,
                            model_info: EmbedModelInfo) -> None:
-
     vllm_extra_kwargs: dict[str, Any] = {}
     if model_info.architecture == "GteNewModel":
         vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
@@ -73,7 +72,6 @@ def test_embed_models_mteb(hf_runner, vllm_runner,
 def test_embed_models_correctness(hf_runner, vllm_runner,
                                   model_info: EmbedModelInfo,
                                   example_prompts) -> None:
-
     vllm_extra_kwargs: dict[str, Any] = {}
     if model_info.architecture == "GteNewModel":
         vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}

From 5b8ff6188fe5e687f522e518f4525b24f2f52b21 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Thu, 31 Jul 2025 13:32:40 +0800
Subject: [PATCH 6/8] fix

Signed-off-by: wang.yuqi
---
 vllm/model_executor/models/bert.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 54fa5588bece..ce5034398d35 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -27,9 +27,9 @@
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
 
+from ...compilation.decorators import support_torch_compile
 from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
-from ...compilation.decorators import support_torch_compile
 
 
 class BertEmbedding(nn.Module):

From 0edc5b8f0bf8c1e62f6082ff6af87a0cc101c866 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Thu, 31 Jul 2025 13:34:04 +0800
Subject: [PATCH 7/8] fix

Signed-off-by: wang.yuqi
---
 vllm/model_executor/models/bert.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index ce5034398d35..504621c8abd8 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -9,6 +9,7 @@
 from transformers import BertConfig
 
 from vllm.attention import Attention, AttentionType
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -27,7 +28,6 @@
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
 
-from ...compilation.decorators import support_torch_compile
 from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 

From 2766a4bac0df120f30808b423b9141486d79aa4e Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Thu, 31 Jul 2025 13:37:50 +0800
Subject: [PATCH 8/8] fix

Signed-off-by: wang.yuqi
---
 vllm/model_executor/models/bert_with_rope.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py
index 3ee0e664e838..59033cb74a33 100644
--- a/vllm/model_executor/models/bert_with_rope.py
+++ b/vllm/model_executor/models/bert_with_rope.py
@@ -25,7 +25,6 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models import SupportsV0Only
 from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.model_executor.models.utils import WeightsMapper
 from vllm.model_executor.utils import set_weight_attrs
@@ -392,7 +391,6 @@ def forward(
         return hidden_states
 
 
-class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
+class BertWithRope(nn.Module, SupportsQuant):
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
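
A note on the vllm/config.py change in PATCH 1/8: try_verify_and_update_config() mutates the model config through per-architecture hooks (for example, the rotary_kwargs rewrite in vllm/model_executor/models/config.py above), so running it twice could apply those rewrites on top of already-rewritten values. The new config_updated flag makes the call idempotent. Below is a minimal, self-contained sketch of that guard pattern; the two classes are simplified stand-ins for illustration only, not vLLM's real ModelConfig/VllmConfig definitions.

# Minimal sketch of the idempotency guard from PATCH 1/8
# (stand-in classes, not vLLM's real types).
class ModelConfig:

    def __init__(self) -> None:
        # Avoid running try_verify_and_update_config multiple times
        self.config_updated = False


class VllmConfig:

    def __init__(self, model_config: "ModelConfig | None") -> None:
        self.model_config = model_config

    def try_verify_and_update_config(self) -> None:
        if self.model_config is None:
            return
        # getattr() tolerates configs created before the flag existed.
        if getattr(self.model_config, "config_updated", False):
            return
        self.model_config.config_updated = True
        print("per-architecture verification runs exactly once")


config = VllmConfig(ModelConfig())
config.try_verify_and_update_config()  # performs the update
config.try_verify_and_update_config()  # returns early on the repeat call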