diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 157fa8d68de5..aece0022d940 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -581,7 +581,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
 | `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | ✅︎ |

 !!! note
-    Named Entity Recognition (NER) usage, please refer to , .
+    Named Entity Recognition (NER) usage, please refer to , .


 [](){ #supported-mm-models }
diff --git a/examples/online_serving/pooling/README.md b/examples/online_serving/pooling/README.md
index 2c271b6a32bc..82e91d3129f0 100644
--- a/examples/online_serving/pooling/README.md
+++ b/examples/online_serving/pooling/README.md
@@ -15,7 +15,7 @@ python examples/online_serving/pooling/jinaai_rerank_client.py
 ## Named Entity Recognition (NER) usage

 ```bash
-python examples/online_serving/pooling/ner.py
+python examples/online_serving/pooling/ner_client.py
 ```

 ## Openai chat embedding for multimodal usage
diff --git a/examples/online_serving/pooling/ner.py b/examples/online_serving/pooling/ner_client.py
similarity index 100%
rename from examples/online_serving/pooling/ner.py
rename to examples/online_serving/pooling/ner_client.py
diff --git a/tests/ci_envs.py b/tests/ci_envs.py
index 596a05b9e5f3..f3a54f308cd8 100644
--- a/tests/ci_envs.py
+++ b/tests/ci_envs.py
@@ -8,6 +8,8 @@
 from collections.abc import Callable
 from typing import TYPE_CHECKING, Any

+from vllm.envs import maybe_convert_bool
+
 if TYPE_CHECKING:
     VLLM_CI_NO_SKIP: bool = False
     VLLM_CI_DTYPE: str | None = None
@@ -25,6 +27,10 @@
     "VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None),
     # Allow changing the head dtype used by transformers in tests
     "VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None),
+    # Allow control over whether tests use enforce_eager
+    "VLLM_CI_ENFORCE_EAGER": lambda: maybe_convert_bool(
+        os.getenv("VLLM_CI_ENFORCE_EAGER", None)
+    ),
 }


diff --git a/tests/entrypoints/pooling/llm/test_classify.py b/tests/entrypoints/pooling/llm/test_classify.py
index ae216c464a5b..488c82c9fe7f 100644
--- a/tests/entrypoints/pooling/llm/test_classify.py
+++ b/tests/entrypoints/pooling/llm/test_classify.py
@@ -58,7 +58,9 @@ def get_outputs(activation):
     )


+@pytest.mark.skip_global_cleanup
 def test_encode_api(llm: LLM):
+    # chunked prefill does not support all pooling
     err_msg = "pooling_task must be one of.+"
     with pytest.raises(ValueError, match=err_msg):
         llm.encode(prompts, use_tqdm=False)
diff --git a/tests/entrypoints/pooling/llm/test_embedding.py b/tests/entrypoints/pooling/llm/test_embedding.py
index aa24a70fd18b..c53941390bd1 100644
--- a/tests/entrypoints/pooling/llm/test_embedding.py
+++ b/tests/entrypoints/pooling/llm/test_embedding.py
@@ -35,7 +35,6 @@ def llm():
     cleanup_dist_env_and_memory()


-@pytest.mark.skip_global_cleanup
 def test_pooling_params(llm: LLM):
     def get_outputs(normalize):
         outputs = llm.embed(
diff --git a/tests/entrypoints/pooling/llm/test_encode.py b/tests/entrypoints/pooling/llm/test_encode.py
index d6aae99944f8..9ba380334e5a 100644
--- a/tests/entrypoints/pooling/llm/test_encode.py
+++ b/tests/entrypoints/pooling/llm/test_encode.py
@@ -74,7 +74,6 @@ def test_multiple_pooling_params(llm: LLM):
     assert len(PROMPTS) == len(outputs)


-@pytest.mark.skip_global_cleanup
 def test_right_side_truncation(llm: LLM):
     # Embeddings models should truncate the end of the prompt
     tokenizer = llm.get_tokenizer()
diff --git a/tests/entrypoints/pooling/llm/test_score.py b/tests/entrypoints/pooling/llm/test_score.py
index 9bf74fce906b..2df973dd7863 100644
--- a/tests/entrypoints/pooling/llm/test_score.py
+++ b/tests/entrypoints/pooling/llm/test_score.py
@@ -33,7 +33,6 @@ def llm():
     cleanup_dist_env_and_memory()


-@pytest.mark.skip_global_cleanup
 def test_pooling_params(llm: LLM):
     def get_outputs(activation):
         text_1 = "What is the capital of France?"
diff --git a/tests/models/language/generation_ppl_test/ppl_utils.py b/tests/models/language/generation_ppl_test/ppl_utils.py
index cfa09635effc..59740505e827 100644
--- a/tests/models/language/generation_ppl_test/ppl_utils.py
+++ b/tests/models/language/generation_ppl_test/ppl_utils.py
@@ -3,12 +3,15 @@
 # Adapted from https://huggingface.co/docs/transformers/perplexity
 from typing import cast

-import pytest
 import torch
 from datasets import load_dataset

 import tests.ci_envs as ci_envs
-from tests.models.utils import GenerateModelInfo, TokensTextLogprobsPromptLogprobs
+from tests.models.utils import (
+    GenerateModelInfo,
+    TokensTextLogprobsPromptLogprobs,
+    get_vllm_extra_kwargs,
+)
 from vllm.logprobs import Logprob

 # See #24485
@@ -25,27 +28,10 @@ def wikitext_ppl_test(
     vllm_extra_kwargs=None,
     atol=PPL_TOL,
 ):
-    # A model family has many models with the same architecture,
-    # and we don't need to test each one.
-    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
-        pytest.skip("Skipping test.")
+    vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)

     dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

-    # Allow vllm to test using the given dtype, such as float32
-    vllm_extra_kwargs = vllm_extra_kwargs or {}
-    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
-
-    # Allow vllm to test using hf_overrides
-    if model_info.hf_overrides is not None:
-        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
-
-    # Allow changing the head dtype used by vllm in tests
-    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
-        if "hf_overrides" not in vllm_extra_kwargs:
-            vllm_extra_kwargs["hf_overrides"] = {}
-        vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
-
     with vllm_runner(
         model_info.name,
         gpu_memory_utilization=0.7,
diff --git a/tests/models/language/pooling/test_head_dtype.py b/tests/models/language/pooling/test_head_dtype.py
new file mode 100644
index 000000000000..b60d4dade49a
--- /dev/null
+++ b/tests/models/language/pooling/test_head_dtype.py
@@ -0,0 +1,47 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+from transformers import AutoModelForSequenceClassification
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["nie3e/sentiment-polish-gpt2-small"],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_classify_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    with hf_runner(
+        model, dtype=dtype, auto_cls=AutoModelForSequenceClassification
+    ) as hf_model:
+        hf_outputs = hf_model.classify(example_prompts)
+
+    for head_dtype_str in ["float32", "model"]:
+        with vllm_runner(
+            model,
+            max_model_len=512,
+            dtype=dtype,
+            hf_overrides={"head_dtype": head_dtype_str},
+        ) as vllm_model:
+            model_config = vllm_model.llm.llm_engine.model_config
+            model_dtype = model_config.dtype
+            head_dtype = model_config.head_dtype
+
+            if head_dtype_str == "float32":
+                assert head_dtype == torch.float32
+            elif head_dtype_str == "model":
+                assert head_dtype == model_dtype
+
+            vllm_outputs = vllm_model.classify(example_prompts)
+
+            for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+                hf_output = torch.tensor(hf_output).float()
+                vllm_output = torch.tensor(vllm_output).float()
+
+                assert torch.allclose(hf_output, vllm_output, atol=1e-2)
diff --git a/tests/models/language/pooling/test_splade_sparse_pooler.py b/tests/models/language/pooling/test_splade_sparse_pooler.py
index 636a6f2f9d74..af4fd764ef53 100644
--- a/tests/models/language/pooling/test_splade_sparse_pooler.py
+++ b/tests/models/language/pooling/test_splade_sparse_pooler.py
@@ -3,7 +3,6 @@

 import types

-import numpy as np
 import pytest
 import torch
 import torch.nn as nn
@@ -14,11 +13,12 @@
 )

 # ---------------------------------------------------------------------
-# 1) Functional test: SPLADE formula correctness (no HF download needed)
+# Functional test: SPLADE formula correctness (no HF download needed)
 # ---------------------------------------------------------------------


 @pytest.mark.parametrize("B,T,H,V", [(2, 3, 5, 7)])
+@torch.inference_mode
 def test_splade_pooler_matches_reference_formula(B, T, H, V):
     """Ensure SPLADESparsePooler forward() matches the mathematical formula:
     log1p(relu(logits)) -> max over sequence length (after masking)."""
@@ -26,9 +26,11 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):

     # Prepare [B] sequences of shape [T, H]
     hs_list = [torch.randn(T, H) for _ in range(B)]
+    hs_tensor = torch.cat(hs_list)

     # Simulate PoolingMetadata (only required fields)
     prompt_lens = [T, T - 1]
+    prompt_lens_tensor = torch.tensor(prompt_lens, dtype=torch.int32)
     token_ids = torch.tensor(
         [
             [101, 5, 102],  # Batch 0: [CLS], token, [SEP]
@@ -36,7 +38,9 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
         ],
         dtype=torch.long,
     )
-    meta = types.SimpleNamespace(prompt_lens=prompt_lens, prompt_token_ids=token_ids)
+    meta = types.SimpleNamespace(
+        prompt_lens=prompt_lens_tensor, prompt_token_ids=token_ids
+    )

     # MLM head (prefer BertMLMHead, fallback to Linear if unavailable)
     try:
@@ -46,10 +50,10 @@

     # Forward pass through SPLADE pooler
     pooler = SPLADESparsePooler(mlm_head=mlm_head, pooling="max", remove_cls_sep=True)
-    pooled = pooler(hidden_states=hs_list, pooling_metadata=meta)  # list of [V]
+    pooled = pooler(hidden_states=hs_tensor, pooling_metadata=meta)  # [B, V]

     # Basic output checks
-    assert isinstance(pooled, list) and len(pooled) == B
+    assert isinstance(pooled, torch.Tensor) and len(pooled) == B
     for vec in pooled:
         assert vec.shape == (V,)
         assert torch.isfinite(vec).all()
@@ -83,40 +87,3 @@ def ref_one(hs: torch.Tensor, L: int, tid_row: torch.Tensor) -> torch.Tensor:
         rtol=1e-4,
         atol=1e-4,
     )
-
-
-# ---------------------------------------------------------------------
-# 2) Integration smoke test: end-to-end embedding path wiring
-# ---------------------------------------------------------------------
-
-
-@pytest.mark.cpu_model
-def test_bert_splade_sparse_embed_smoke(vllm_runner, monkeypatch):
-    """Ensure BertSpladeSparseEmbeddingModel loads and produces sparse embeddings."""
-    from transformers import AutoTokenizer
-
-    MODEL_ID = "hf-internal-testing/tiny-random-bert"
-    hf_overrides = {"architectures": ["BertSpladeSparseEmbeddingModel"]}
-
-    # Enforce CPU-only execution (optional)
-    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
-    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
-
-    tok = AutoTokenizer.from_pretrained(MODEL_ID)
-    vocab_size = tok.vocab_size
-
-    # The embed path should route through SPLADESparsePooler
-    with vllm_runner(
-        MODEL_ID,
-        runner="pooling",
-        max_model_len=64,
-        hf_overrides=hf_overrides,
-    ) as vm:
-        outs = vm.embed(["hello world", "splade sparse test"])
-
-    # Basic sanity checks
-    assert len(outs) == 2
-    assert outs[0].shape[0] == vocab_size
-    assert outs[1].shape[0] == vocab_size
-    assert np.isfinite(outs[0]).all() and (outs[0] >= 0).all()
-    assert np.isfinite(outs[1]).all() and (outs[1] >= 0).all()
diff --git a/tests/models/language/pooling_mteb_test/mteb_utils.py b/tests/models/language/pooling_mteb_test/mteb_utils.py
index f2a817737749..0384ff82790f 100644
--- a/tests/models/language/pooling_mteb_test/mteb_utils.py
+++ b/tests/models/language/pooling_mteb_test/mteb_utils.py
@@ -6,12 +6,16 @@

 import mteb
 import numpy as np
-import pytest
 import requests
 import torch

 import tests.ci_envs as ci_envs
-from tests.models.utils import EmbedModelInfo, RerankModelInfo, check_embeddings_close
+from tests.models.utils import (
+    EmbedModelInfo,
+    RerankModelInfo,
+    check_embeddings_close,
+    get_vllm_extra_kwargs,
+)

 # Most embedding models on the STS12 task (See #17175):
 # - Model implementation and minor changes in tensor dtype
@@ -165,28 +169,11 @@ def mteb_test_embed_models(
     hf_model_callback=None,
     atol=MTEB_EMBED_TOL,
 ):
-    # A model family has many models with the same architecture,
-    # and we don't need to test each one.
-    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
-        pytest.skip("Skipping test.")
+    vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)

     # Test embed_dims, isnan and whether to use normalize
     example_prompts = ["The chef prepared a delicious meal." * 1000]

-    # Allow vllm to test using the given dtype, such as float32
-    vllm_extra_kwargs = vllm_extra_kwargs or {}
-    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
-
-    # Allow vllm to test using hf_overrides
-    if model_info.hf_overrides is not None:
-        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
-
-    # Allow changing the head dtype used by vllm in tests
-    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
-        if "hf_overrides" not in vllm_extra_kwargs:
-            vllm_extra_kwargs["hf_overrides"] = {}
-        vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
-
     with vllm_runner(
         model_info.name,
         runner="pooling",
@@ -212,9 +199,12 @@
         vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
         head_dtype = model_config.head_dtype

-        # Test embed_dims, isnan and whether to use normalize
+        # Test embedding_size, isnan and whether to use normalize
         vllm_outputs = vllm_model.embed(example_prompts, truncate_prompt_tokens=-1)
-        assert not torch.any(torch.isnan(torch.tensor(vllm_outputs)))
+        outputs_tensor = torch.tensor(vllm_outputs)
+        assert not torch.any(torch.isnan(outputs_tensor))
+        embedding_size = model_config.embedding_size
+        assert outputs_tensor.shape[-1] == embedding_size

     # Accelerate mteb test by setting
     # SentenceTransformers mteb score to a constant
@@ -231,7 +221,7 @@
         st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
         st_dtype = next(hf_model.model.parameters()).dtype

-        # Test embed_dims and whether to use normalize
+        # Check embeddings close to hf outputs
         hf_outputs = hf_model.encode(example_prompts)
         check_embeddings_close(
             embeddings_0_lst=hf_outputs,
@@ -323,24 +313,7 @@ def mteb_test_rerank_models(
     vllm_mteb_encoder=VllmMtebEncoder,
     atol=MTEB_RERANK_TOL,
 ):
-    # A model family has many models with the same architecture,
-    # and we don't need to test each one.
-    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
-        pytest.skip("Skipping test.")
-
-    # Allow vllm to test using the given dtype, such as float32
-    vllm_extra_kwargs = vllm_extra_kwargs or {}
-    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
-
-    # Allow vllm to test using hf_overrides
-    if model_info.hf_overrides is not None:
-        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
-
-    # Allow changing the head dtype used by vllm in tests
-    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
-        if "hf_overrides" not in vllm_extra_kwargs:
-            vllm_extra_kwargs["hf_overrides"] = {}
-        vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
+    vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)

     with vllm_runner(
         model_info.name,
diff --git a/tests/models/utils.py b/tests/models/utils.py
index 82da4aa64921..f5c16b3c6542 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -15,6 +15,7 @@
 from vllm.multimodal.processing import InputProcessingContext
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

+from .. import ci_envs
 from .registry import HF_EXAMPLE_MODELS

 TokensText = tuple[list[int], str]
@@ -414,6 +415,35 @@ class GenerateModelInfo(ModelInfo):
     hf_ppl: float | None = None


+def get_vllm_extra_kwargs(model_info: ModelInfo, vllm_extra_kwargs):
+    # A model family has many models with the same architecture,
+    # and we don't need to test each one.
+    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
+        import pytest
+
+        pytest.skip("Skipping test.")
+
+    # Allow vllm to test using the given dtype, such as float32
+    vllm_extra_kwargs = vllm_extra_kwargs or {}
+    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
+
+    # Allow vllm to test using hf_overrides
+    if model_info.hf_overrides is not None:
+        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
+
+    # Allow changing the head dtype used by vllm in tests
+    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
+        if "hf_overrides" not in vllm_extra_kwargs:
+            vllm_extra_kwargs["hf_overrides"] = {}
+        vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
+
+    # Allow control over whether tests use enforce_eager
+    if ci_envs.VLLM_CI_ENFORCE_EAGER is not None:
+        vllm_extra_kwargs["enforce_eager"] = ci_envs.VLLM_CI_ENFORCE_EAGER
+
+    return vllm_extra_kwargs
+
+
 def dummy_hf_overrides(
     hf_config: PretrainedConfig,
     *,
diff --git a/vllm/config/model.py b/vllm/config/model.py
index a2dcf5210754..0069dc6cca94 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -30,6 +30,7 @@
     get_sentence_transformer_tokenizer_config,
     is_encoder_decoder,
     is_interleaved,
+    try_get_dense_modules,
     try_get_generation_config,
     try_get_safetensors_metadata,
     try_get_tokenizer_config,
@@ -1681,6 +1682,20 @@ def head_dtype(self) -> torch.dtype:
         logger.debug_once("head dtype: %s", head_dtype)
         return head_dtype

+    @property
+    def hidden_size(self):
+        if hasattr(self.hf_config, "hidden_size"):
+            return self.hf_config.hidden_size
+        text_config = self.hf_config.get_text_config()
+        return text_config.hidden_size
+
+    @property
+    def embedding_size(self):
+        dense_modules = try_get_dense_modules(self.model, revision=self.revision)
+        if dense_modules is not None:
+            return dense_modules[-1]["out_features"]
+        return self.hidden_size
+
     def get_and_verify_max_len(self, max_model_len: int):
         # Consider max_model_len in tokenizer_config only when
         # pooling models use absolute position_embedding.
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index 6d035f93dd9b..1d3874b16484 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -13,7 +13,10 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.models.config import VerifyAndUpdateConfig
-from vllm.transformers_utils.config import get_hf_file_bytes, get_hf_file_to_dict
+from vllm.transformers_utils.config import (
+    get_hf_file_bytes,
+    try_get_dense_modules,
+)

 from .interfaces_base import VllmModelForPooling, is_pooling_model

@@ -35,43 +38,25 @@ def _load_st_projector(model_config: "ModelConfig") -> nn.Module | None:
     """Load Sentence-Transformers Dense projection layers."""

-    try:
-        modules = get_hf_file_to_dict(
-            "modules.json", model_config.model, model_config.revision
-        )
-        if not modules:
-            return None
-
-        if isinstance(modules, dict):
-            modules = modules.get("modules", [])
+    dense_modules = try_get_dense_modules(
+        model_config.model, revision=model_config.revision
+    )

-        dense_modules = [
-            m for m in modules if m.get("type") == "sentence_transformers.models.Dense"
-        ]
-        if not dense_modules:
-            return None
+    if dense_modules is None:
+        return None

+    try:
         layers = []
-        for module in dense_modules:
-            folder = module.get("path", "")
-
-            config_path = f"{folder}/config.json" if folder else "config.json"
-            layer_config = get_hf_file_to_dict(
-                config_path, model_config.model, model_config.revision
-            )
-            if not layer_config:
-                continue
-
+        for layer_config in dense_modules:
+            folder = layer_config["folder"]
             linear = nn.Linear(
-                layer_config.get("in_features", 768),
-                layer_config.get("out_features", 768),
+                layer_config["in_features"],
+                layer_config["out_features"],
                 bias=layer_config.get("bias", True),
                 dtype=model_config.head_dtype,
             )
-
             if not _load_dense_weights(linear, folder, model_config):
                 continue
-
             layers.append(linear)
             if act_name := layer_config.get("activation_function"):
                 layers.append(get_act_fn(act_name))

@@ -303,18 +288,18 @@ def as_seq_cls_model(cls: _T) -> _T:
     from vllm.model_executor.models.interfaces import SupportsCrossEncoding
     from vllm.sequence import IntermediateTensors

-    from .utils import get_model_hidden_size, maybe_prefix
+    from .utils import maybe_prefix

     class ModelForSequenceClassification(
         _create_pooling_model_cls(cls), SupportsCrossEncoding
     ):
         def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
             config = vllm_config.model_config.hf_config
+            model_config = vllm_config.model_config
             quant_config = vllm_config.quant_config

-            hidden_size = get_model_hidden_size(config)
             self.score = ReplicatedLinear(
-                hidden_size,
+                model_config.hidden_size,
                 config.num_labels,
                 bias=False,
                 params_dtype=torch.float32,
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 4cafe724f1ca..ddd6e53b4a43 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -50,7 +50,7 @@
 from vllm.sequence import IntermediateTensors

 from ..layers.pooler import DispatchPooler, Pooler
-from .interfaces import SupportsPP
+from .interfaces import SupportsCrossEncoding, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     is_pp_missing_parameter,
@@ -321,7 +321,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         return loader.load_weights(weights)


-class GPT2ForSequenceClassification(nn.Module):
+class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """GPT2 Model for sequence classification.

     This class expands GPT2Model with pooling and score functions - last token
@@ -358,6 +358,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             }
         )

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.get_input_embeddings(input_ids)
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py
index 756a3900965b..ede3e34881b1 100644
--- a/vllm/model_executor/models/gritlm.py
+++ b/vllm/model_executor/models/gritlm.py
@@ -148,37 +148,6 @@ def get_supported_tasks(self) -> Set[PoolingTask]:
     def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return PoolingParamsUpdate(requires_token_ids=True)

-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: torch.Tensor | None = None,
-        instr_len: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        assert prompt_len is None or prompt_len == hidden_states.shape[0], (
-            "partial prefill not supported with MEAN pooling"
-        )
-
-        return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32)
-
-    def forward_all(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
-        instr_lens: torch.Tensor,
-    ) -> list[torch.Tensor] | torch.Tensor:
-        offset = 0
-        pooled_data = list[torch.Tensor]()
-
-        for prompt_len, instr_len in zip(prompt_lens, instr_lens):
-            pooled_data.append(
-                hidden_states[offset + instr_len : offset + prompt_len].mean(
-                    dim=0, dtype=torch.float32
-                )
-            )
-            offset += prompt_len
-
-        return pooled_data
-
     def forward(
         self,
         hidden_states: torch.Tensor | list[torch.Tensor],
@@ -190,18 +159,20 @@ def forward(
                 self._get_instruction_len(token_ids.cpu().numpy())
                 for token_ids in get_prompt_token_ids(pooling_metadata)
             ],
-            device=prompt_lens.device,
+            device="cpu",
         )

-        if isinstance(hidden_states, list):
-            return [
-                self.forward_one(h, prompt_len, instr_len)
-                for h, prompt_len, instr_len in zip(
-                    hidden_states, prompt_lens, instr_lens
+        offset = 0
+        pooled_data = list[torch.Tensor]()
+        for prompt_len, instr_len in zip(prompt_lens, instr_lens):
+            pooled_data.append(
+                hidden_states[offset + instr_len : offset + prompt_len].mean(
+                    dim=0, dtype=torch.float32
                 )
-            ]
+            )
+            offset += prompt_len

-        return self.forward_all(hidden_states, prompt_lens, instr_lens)
+        return pooled_data


 class GritLMPooler(Pooler):
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 8812ed177f56..71abfe98813d 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -777,13 +777,6 @@ def fast_topk(
     return torch.topk(values, topk, dim=dim)


-def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
-    if hasattr(hf_config, "hidden_size"):
-        return hf_config.hidden_size
-    text_config = hf_config.get_text_config()
-    return text_config.hidden_size
-
-
 # Chunk x along the num_tokens axis for sequence parallelism
 # NOTE: This is wrapped in a torch custom op to work around the following issue:
 # The output tensor can have a sequence length 0 at small input sequence lengths
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 58e0b5314602..623e17b05a6e 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -1049,6 +1049,40 @@ def try_get_tokenizer_config(
     return None


+@cache
+def try_get_dense_modules(
+    model: str | Path,
+    revision: str | None = None,
+) -> list[dict[str, Any]] | None:
+    try:
+        modules = get_hf_file_to_dict("modules.json", model, revision)
+        if not modules:
+            return None
+
+        if isinstance(modules, dict):
+            modules = modules.get("modules", [])
+
+        dense_modules = [
+            m for m in modules if m.get("type") == "sentence_transformers.models.Dense"
+        ]
+        if not dense_modules:
+            return None
+
+        layer_configs = []
+        for module in dense_modules:
+            folder = module.get("path", "")
+
+            config_path = f"{folder}/config.json" if folder else "config.json"
+            layer_config = get_hf_file_to_dict(config_path, model, revision)
+            if not layer_config:
+                continue
+            layer_config["folder"] = folder
+            layer_configs.append(layer_config)
+        return layer_configs
+    except Exception:
+        return None
+
+
 def get_safetensors_params_metadata(
     model: str,
     *,