
Commit cbf1489

DarkLight1337 authored and afeldman-nm committed
[Model] Replace embedding models with pooling adapter (vllm-project#10769)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
1 parent a877540 commit cbf1489

File tree: 32 files changed (+387 / -323 lines)

.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 2 deletions
@@ -343,7 +343,6 @@ steps:
   commands:
   - pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
   - pytest -v -s models/embedding/language -m core_model
-  - pytest -v -s models/embedding/vision_language -m core_model
 
 - label: Language Models Test (Extended) # 50min
   optional: true
@@ -355,7 +354,6 @@
   commands:
   - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
   - pytest -v -s models/embedding/language -m 'not core_model'
-  - pytest -v -s models/embedding/vision_language -m 'not core_model'
 
 - label: Multi-Modal Models Test (Standard) # 26min
   #mirror_hardwares: [amd]
@@ -368,6 +366,7 @@
   commands:
   - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
   - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
+  - pytest -v -s models/embedding/vision_language -m core_model
   - pytest -v -s models/encoder_decoder/language -m core_model
   - pytest -v -s models/encoder_decoder/vision_language -m core_model
@@ -385,6 +384,7 @@
   # https://github.com/huggingface/transformers/issues/34307
   - pytest -v -s models/decoder_only/vision_language/test_phi3v.py
   - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
+  - pytest -v -s models/embedding/vision_language -m 'not core_model'
   - pytest -v -s models/encoder_decoder/language -m 'not core_model'
   - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'

docs/source/models/supported_models.rst

Lines changed: 14 additions & 1 deletion
@@ -357,7 +357,7 @@ Text Embedding
     - ✅︎
   * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
     - Qwen2-based
-    - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
+    - :code:`ssmits/Qwen2-7B-Instruct-embed-base` (see note), :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
     - ✅︎
     - ✅︎
   * - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
@@ -378,6 +378,10 @@ Text Embedding
 .. tip::
     You can override the model's pooling method by passing :code:`--override-pooler-config`.
 
+.. note::
+    :code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
+    You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`.
+
 .. note::
     Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
     You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
@@ -397,12 +401,21 @@ Reward Modeling
     - Example HF Models
     - :ref:`LoRA <lora>`
     - :ref:`PP <distributed_serving>`
+  * - :code:`LlamaForCausalLM`
+    - Llama-based
+    - :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
+    - ✅︎
+    - ✅︎
   * - :code:`Qwen2ForRewardModel`
     - Qwen2-based
     - :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
     - ✅︎
     - ✅︎
 
+.. important::
+    For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
+    e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
+
 .. note::
     As an interim measure, these models are supported in both offline and online inference via Embeddings API.
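
Both new notes rely on the same mechanism: the JSON passed to `--override-pooler-config` is parsed into vLLM's `PoolerConfig` and attached to the loaded model. Below is a minimal offline sketch of the mean-pooling case, assuming `override_pooler_config` is accepted as an `LLM` keyword argument in the same way the test further down passes it; the model name comes from the table above and the prompt is only illustrative.

```python
# Sketch only: offline equivalent of
#   --override-pooler-config '{"pooling_type": "MEAN"}'
from vllm import LLM
from vllm.config import PoolerConfig

llm = LLM(
    model="ssmits/Qwen2-7B-Instruct-embed-base",
    task="embedding",
    override_pooler_config=PoolerConfig(pooling_type="MEAN"),
)

outputs = llm.encode(["What is the capital of France?"])
print(len(outputs[0].outputs.embedding))  # dimensionality of the pooled embedding
```

When serving, the same override is passed as the JSON string shown in the note above.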

tests/conftest.py

Lines changed: 0 additions & 1 deletion
@@ -263,7 +263,6 @@ def __init__(
         dtype: str = "half",
         *,
         model_kwargs: Optional[Dict[str, Any]] = None,
-        is_embedding_model: bool = False,
         is_sentence_transformer: bool = False,
         is_cross_encoder: bool = False,
         skip_tokenizer_init: bool = False,

tests/models/embedding/language/test_embedding.py

Lines changed: 5 additions & 0 deletions
@@ -4,6 +4,8 @@
 """
 import pytest
 
+from vllm.config import PoolerConfig
+
 from ..utils import check_embeddings_close
 
 
@@ -33,6 +35,9 @@ def test_models(
     dtype: str,
 ) -> None:
     vllm_extra_kwargs = {}
+    if model == "ssmits/Qwen2-7B-Instruct-embed-base":
+        vllm_extra_kwargs["override_pooler_config"] = \
+            PoolerConfig(pooling_type="MEAN")
     if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
         vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}

tests/models/test_registry.py

Lines changed: 14 additions & 17 deletions
@@ -6,11 +6,8 @@
 from vllm.model_executor.models import (is_embedding_model,
                                         is_text_generation_model,
                                         supports_multimodal)
-# yapf conflicts with isort for this block
-# yapf: disable
-from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS,
-                                                  _EMBEDDING_MODELS,
-                                                  _MULTIMODAL_MODELS,
+from vllm.model_executor.models.adapters import as_embedding_model
+from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
                                                   _SPECULATIVE_DECODING_MODELS,
                                                   _TEXT_GENERATION_MODELS,
                                                   ModelRegistry)
@@ -26,18 +23,18 @@ def test_registry_imports(model_arch):
     model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)
 
     if model_arch in _SPECULATIVE_DECODING_MODELS:
-        pass  # Ignore these models which do not have a unified format
-    else:
-        assert is_text_generation_model(model_cls) is (
-            model_arch in _TEXT_GENERATION_MODELS
-            or model_arch in _MULTIMODAL_MODELS)
-
-        embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS}
-        assert is_embedding_model(model_cls) is (model_arch
-                                                 in embedding_models)
-
-        assert supports_multimodal(model_cls) is (model_arch
-                                                  in _MULTIMODAL_MODELS)
+        return  # Ignore these models which do not have a unified format
+
+    if (model_arch in _TEXT_GENERATION_MODELS
+            or model_arch in _MULTIMODAL_MODELS):
+        assert is_text_generation_model(model_cls)
+
+    # All vLLM models should be convertible to an embedding model
+    embed_model = as_embedding_model(model_cls)
+    assert is_embedding_model(embed_model)
+
+    if model_arch in _MULTIMODAL_MODELS:
+        assert supports_multimodal(model_cls)
 
 
 @fork_new_process_for_each_test
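
The switch from dedicated `*EmbeddingModel` classes to `as_embedding_model` is the core of this commit: any registered generative model class can be converted into an embedding model by attaching a pooler. The sketch below illustrates the idea only; the real adapter lives in `vllm/model_executor/models/adapters.py`, and the subclass name and pooling defaults here are assumptions patterned on the dummy plugin model further down.

```python
# Conceptual sketch of a pooling adapter; not the actual implementation in
# vllm/model_executor/models/adapters.py.
from vllm.model_executor.layers.pooler import Pooler, PoolingType


def as_embedding_model_sketch(cls):
    """Wrap a generative model class so it also exposes a pooler."""

    class ModelForEmbedding(cls):

        def __init__(self, *, vllm_config, prefix: str = ""):
            super().__init__(vllm_config=vllm_config, prefix=prefix)
            # Assumed defaults: pool the last token and normalize, matching
            # the dummy Gemma2 embedding model in this commit.
            self._pooler = Pooler.from_config_with_defaults(
                vllm_config.model_config.pooler_config,
                pooling_type=PoolingType.LAST,
                normalize=True,
                softmax=False,
            )

        def pooler(self, hidden_states, pooling_metadata):
            return self._pooler(hidden_states, pooling_metadata)

    ModelForEmbedding.__name__ = cls.__name__ + "ForEmbedding"
    return ModelForEmbedding
```

The real adapter presumably also has to skip `lm_head.*` weights during loading, which the plugin model further down does explicitly.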

tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py

Lines changed: 40 additions & 5 deletions
@@ -1,13 +1,34 @@
-from typing import List, Optional, Union
+from typing import Iterable, List, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
 
 from vllm.attention import AttentionMetadata
-from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel
-from vllm.sequence import IntermediateTensors
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.models.gemma2 import Gemma2Model
+from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.sequence import IntermediateTensors, PoolerOutput
 
 
-class MyGemma2Embedding(Gemma2EmbeddingModel):
+class MyGemma2Embedding(nn.Module):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
+        self.model = Gemma2Model(vllm_config=vllm_config,
+                                 prefix=maybe_prefix(prefix, "model"))
+
+        self._pooler = Pooler.from_config_with_defaults(
+            vllm_config.model_config.pooler_config,
+            pooling_type=PoolingType.LAST,
+            normalize=True,
+            softmax=False,
+        )
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
 
     def forward(
         self,
@@ -18,7 +39,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = super().forward(
+        hidden_states = self.model(
             input_ids,
             positions,
             kv_caches,
@@ -32,3 +53,17 @@ def forward(
 
         # Return all-zero embeddings
         return torch.zeros_like(hidden_states)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
+        weights = hf_to_vllm_mapper.apply(weights)
+        weights = ((name, data) for name, data in weights
+                   if not name.startswith("lm_head."))
+        return self.model.load_weights(weights)
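
For context, a plugin model like this is made visible to vLLM by registering its architecture name with the model registry. The registration call itself is not part of this diff; the sketch below assumes the usual out-of-tree registration pattern, with the module path inferred from the file location.

```python
# Typically placed in the plugin package's __init__ / entry point.
from vllm import ModelRegistry

ModelRegistry.register_model(
    "MyGemma2Embedding",
    "vllm_add_dummy_model.my_gemma_embedding:MyGemma2Embedding",
)
```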

tests/test_config.py

Lines changed: 1 addition & 2 deletions
@@ -26,8 +26,7 @@ def test_auto_task(model_id, expected_task):
 
 
 @pytest.mark.parametrize(("model_id", "bad_task"), [
-    ("facebook/opt-125m", "embedding"),
-    ("intfloat/e5-mistral-7b-instruct", "generate"),
+    ("Qwen/Qwen2.5-Math-RM-72B", "generate"),
 ])
 def test_incorrect_task(model_id, bad_task):
     with pytest.raises(ValueError, match=r"does not support the .* task"):

vllm/config.py

Lines changed: 25 additions & 0 deletions
@@ -370,6 +370,31 @@ def _resolve_task(
             selected_task = next(iter(supported_tasks_lst))
 
             if len(supported_tasks) > 1:
+                suffix_to_preferred_task: List[Tuple[str, _Task]] = [
+                    # Hardcode the models that are exceptions
+                    ("AquilaModel", "generate"),
+                    ("ChatGLMModel", "generate"),
+                    # Other models follow this pattern
+                    ("ForCausalLM", "generate"),
+                    ("ForConditionalGeneration", "generate"),
+                    ("ChatModel", "generate"),
+                    ("LMHeadModel", "generate"),
+                    ("EmbeddingModel", "embedding"),
+                    ("RewardModel", "embedding"),
+                    ("ForSequenceClassification", "embedding"),
+                ]
+                info, arch = ModelRegistry.inspect_model_cls(architectures)
+
+                for suffix, pref_task in suffix_to_preferred_task:
+                    if arch.endswith(suffix) and pref_task in supported_tasks:
+                        selected_task = pref_task
+                        break
+                else:
+                    if (arch.endswith("Model")
+                            and info.architecture.endswith("ForCausalLM")
+                            and "embedding" in supported_tasks):
+                        selected_task = "embedding"
+
                 logger.info(
                     "This model supports multiple tasks: %s. "
                     "Defaulting to '%s'.", supported_tasks, selected_task)

vllm/inputs/registry.py

Lines changed: 8 additions & 8 deletions
@@ -11,8 +11,8 @@
 from vllm.logger import init_logger
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
-                        resolve_mm_processor_kwargs)
+from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
+                        print_warning_once, resolve_mm_processor_kwargs)
 
 from .data import ProcessorInputs, SingletonInputs
 from .parse import is_encoder_decoder_inputs
@@ -136,12 +136,12 @@ class InputRegistry:
     """
 
     def __init__(self) -> None:
-        self._dummy_factories_by_model_type: Dict[Type[nn.Module],
-                                                   DummyDataFactory] = {}
-        self._dummy_encoder_factories_by_model_type: Dict[
-            Type[nn.Module], DummyDataFactory] = {}
-        self._input_processors_by_model_type: Dict[Type[nn.Module],
-                                                    InputProcessor] = {}
+        self._dummy_factories_by_model_type = \
+            ClassRegistry[nn.Module, DummyDataFactory]()
+        self._dummy_encoder_factories_by_model_type = \
+            ClassRegistry[nn.Module, DummyDataFactory]()
+        self._input_processors_by_model_type = \
+            ClassRegistry[nn.Module, InputProcessor]()
 
     def _default_dummy_data_factory(
         self,
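
`ClassRegistry` replaces the plain dicts so that a lookup on a dynamically created embedding subclass can still hit the entry registered for its generative base class. Its actual implementation lives in `vllm.utils`; the sketch below shows the assumed behavior (an MRO-aware lookup), not the real code.

```python
from typing import Dict, Generic, Type, TypeVar

import torch.nn as nn

K = TypeVar("K")
V = TypeVar("V")


class ClassRegistrySketch(Generic[K, V]):
    """Assumed behavior: lookups fall back to registered base classes."""

    def __init__(self) -> None:
        self._data: Dict[Type[K], V] = {}

    def __setitem__(self, key: Type[K], value: V) -> None:
        self._data[key] = value

    def __getitem__(self, key: Type[K]) -> V:
        # Walk the MRO so subclasses inherit their base class's entry.
        for base in key.__mro__:
            if base in self._data:
                return self._data[base]
        raise KeyError(key)

    def __contains__(self, key: Type[K]) -> bool:
        return any(base in self._data for base in key.__mro__)


# Why this matters here: an embedding model produced by as_embedding_model
# is a subclass of the original class, so it should reuse that class's
# registered dummy-data factories and input processors.
registry = ClassRegistrySketch[nn.Module, str]()
registry[nn.Linear] = "factory registered for Linear"


class MyLinearSubclass(nn.Linear):
    pass


assert MyLinearSubclass in registry  # found via the base class entry
```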

vllm/model_executor/layers/pooler.py

Lines changed: 1 addition & 3 deletions
@@ -60,9 +60,7 @@ def from_config_with_defaults(
         softmax: bool,
         step_tag_id: Optional[int] = None,
         returned_token_ids: Optional[List[int]] = None,
-    ) -> Optional["Pooler"]:
-        if pooler_config is None:
-            return None
+    ) -> "Pooler":
         return cls(
             pooling_type=PoolingType[pooler_config.pooling_type]
             if pooler_config.pooling_type is not None else pooling_type,
