
Commit f73282e

Remove unnecessary LlamaEmbeddingModel
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 666cc19 commit f73282e

3 files changed: +14 -97 lines changed

docs/source/models/supported_models.rst

Lines changed: 9 additions & 0 deletions
@@ -401,12 +401,21 @@ Reward Modeling
     - Example HF Models
     - :ref:`LoRA <lora>`
     - :ref:`PP <distributed_serving>`
+  * - :code:`LlamaForCausalLM`
+    - Llama-based
+    - :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
+    - ✅︎
+    - ✅︎
   * - :code:`Qwen2ForRewardModel`
     - Qwen2-based
     - :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
     - ✅︎
     - ✅︎
 
+.. important::
+    For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
+    e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
+
 .. note::
     As an interim measure, these models are supported in both offline and online inference via Embeddings API.
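For reference, the following is a minimal offline sketch (not part of this commit) of the pooling override described in the important-note above. The task and override_pooler_config keyword arguments and the PoolerConfig import are assumptions about the surrounding vLLM API, and the step_tag_id / returned_token_ids values are the placeholder IDs from the doc text, not real token IDs for this model:

from vllm import LLM
from vllm.config import PoolerConfig

# Hypothetical offline equivalent of the documented CLI flag:
#   --override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'
llm = LLM(
    model="peiyi9979/math-shepherd-mistral-7b-prm",
    task="embedding",  # assumption: reward models currently ride the embedding path
    override_pooler_config=PoolerConfig(
        pooling_type="STEP",            # pool once per reasoning step, not per sequence
        step_tag_id=123,                # placeholder: token id that marks the end of a step
        returned_token_ids=[456, 789],  # placeholder: candidate label tokens to score
    ),
)

# One pooled score is returned for each step tag found in the prompt.
outputs = llm.encode("Question: ... Step 1: ... Step 2: ...")
print(outputs[0])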

vllm/model_executor/models/llama.py

Lines changed: 3 additions & 95 deletions
@@ -37,7 +37,6 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     get_compressed_tensors_cache_scale)
@@ -47,14 +46,12 @@
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.platforms import current_platform
-from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
-                    is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -497,7 +494,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
-        pooler_config = vllm_config.model_config.pooler_config
         self.config = config
         self.lora_config = lora_config

@@ -530,13 +526,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
+
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.STEP,
-            normalize=False,
-            softmax=False)
 
     def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
         return LlamaModel(vllm_config=vllm_config, prefix=prefix)
@@ -567,14 +559,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        logits = self.compute_logits(hidden_states, None)
-        return self._pooler(logits, pooling_metadata)
-
     def sample(self, logits: torch.Tensor,
                sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(logits, sampling_metadata)
@@ -625,79 +609,3 @@ def permute(w: torch.Tensor, n_heads: int):
                 name = name.replace(item, mapping[item])
 
         return name, loaded_weight
-
-
-# TODO: Remove this once reward modeling is separated from LlamaForCausalLM
-class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
-    """
-    A model that uses Llama with additional embedding functionalities.
-
-    This class encapsulates the LlamaModel and provides an interface for
-    embedding operations and customized pooling functions.
-
-    Attributes:
-        model: An instance of LlamaModel used for forward operations.
-        _pooler: An instance of Pooler used for pooling operations.
-    """
-    packed_modules_mapping = {
-        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"]
-    }
-
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
-    ]
-    embedding_modules = {
-        "embed_tokens": "input_embeddings",
-    }
-    embedding_padding_modules = []
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-
-        pooler_config = vllm_config.model_config.pooler_config
-
-        self.model = LlamaModel(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.LAST,
-            normalize=True,
-            softmax=False)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        return self.model(input_ids, positions, kv_caches, attn_metadata,
-                          intermediate_tensors, inputs_embeds)
-
-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        return self._pooler(hidden_states, pooling_metadata)
-
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
-        weights = hf_to_vllm_mapper.apply(weights)
-        weights = ((name, data) for name, data in weights
-                   if not name.startswith("lm_head."))
-        self.model.load_weights(weights)
-
-    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        self.model.load_kv_cache_scales(quantization_param_path)
-
-    # LRUCacheWorkerLoRAManager instantiation requires model config.
-    @property
-    def config(self):
-        return self.model.config

vllm/model_executor/models/registry.py

Lines changed: 2 additions & 2 deletions
@@ -110,13 +110,13 @@
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
-    "LlamaModel": ("llama", "LlamaEmbeddingModel"),
+    "LlamaModel": ("llama", "LlamaForCausalLM"),
     **{
         # Multiple models share the same architecture, so we include them all
         k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
         if arch == "LlamaForCausalLM"
     },
-    "MistralModel": ("llama", "LlamaEmbeddingModel"),
+    "MistralModel": ("llama", "LlamaForCausalLM"),
     "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
     "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
