
Commit ed9c1ae

Refactor base BertModel and Roberta*Model to use AutoWeightsLoader.

Served the following models to verify the change:

RobertaForSequenceClassification:
- python3 -m vllm.entrypoints.cli.main serve cardiffnlp/twitter-roberta-base-sentiment-latest --served-model-name roberta-sentiment --trust-remote-code
- python3 -m vllm.entrypoints.cli.main serve jinaai/jina-embeddings-v3 --served-model-name jina-v3 --trust-remote-code

RobertaEmbeddingModel:
- python3 -m vllm.entrypoints.cli.main serve FacebookAI/roberta-base --served-model-name roberta-base --trust-remote-code
- python3 -m vllm.entrypoints.cli.main serve sentence-transformers/stsb-roberta-base-v2 --served-model-name stsb-roberta --trust-remote-code

BertEmbeddingModel:
- python3 -m vllm.entrypoints.cli.main serve sentence-transformers/all-MiniLM-L6-v2 --served-model-name bert-embeddings --trust-remote-code

Signed-off-by: Jen H <islandhe@gmail.com>
1 parent 65b80fe commit ed9c1ae
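For reference, once one of the serve commands above is running, the refactored weight loading can be smoke-tested end to end through vLLM's OpenAI-compatible embeddings endpoint. A minimal sketch (assuming the all-MiniLM-L6-v2 command above with served name bert-embeddings, and the default localhost:8000 address; adjust both as needed):

import json
import urllib.request

# Assumes the "bert-embeddings" serve command from the commit message is
# running on the default host/port; change the URL and model name to match
# whichever model is being tested.
payload = {"model": "bert-embeddings", "input": "weight loading smoke test"}
req = urllib.request.Request(
    "http://localhost:8000/v1/embeddings",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    result = json.load(resp)

# A successful load should return one embedding vector of the model's
# hidden size (384 for all-MiniLM-L6-v2).
print(len(result["data"][0]["embedding"]))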

File tree

2 files changed: +54 −73 lines

vllm/model_executor/models/bert.py
vllm/model_executor/models/roberta.py

vllm/model_executor/models/bert.py

Lines changed: 36 additions & 48 deletions
@@ -22,7 +22,6 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
@@ -44,9 +43,13 @@ def __init__(self, config: BertConfig):
             config.type_vocab_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
-        self.register_buffer(
-            "position_ids",
-            torch.arange(config.max_position_embeddings).expand((1, -1)))
+
+        # Use nn.Parameter with requires_grad=False to maintain compatibility
+        # with existing HF checkpoints while ensuring position_ids are
+        # non-trainable.
+        self.position_ids = nn.Parameter(torch.empty(
+            (1, config.max_position_embeddings)),
+                                         requires_grad=False)
 
         self.position_embedding_type = config.position_embedding_type
         if self.position_embedding_type != "absolute":
@@ -359,45 +362,44 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("qkv_proj", "value", "v"),
         ]
 
+        loaded_stacked_params = []
+        other_weights = []
         params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            if self.pooler is None and "pooler" in name:
-                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+
                 name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
+                loaded_stacked_params.append(name)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
+                other_weights.append((name, loaded_weight))
+
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["pooler."] if self.pooler is None else []),
+        )
+        loaded_params = loader.load_weights(other_weights)
+        loaded_params.update(loaded_stacked_params)
         return loaded_params
 
 
 class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
-    This class encapsulates the BertModel and provides an interface for
-    embedding operations and customized pooling functions.
+    This class encapsulates the BertModel and provides an interface for
+    embedding operations and customized pooling functions.
 
-    Attributes:
-        model: An instance of BertModel used for forward operations.
-        _pooler: An instance of Pooler used for pooling operations.
-    """
-    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
+    Attributes:
+        model: An instance of BertModel used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -426,10 +428,15 @@ def pooler(
         return self._pooler(hidden_states, pooling_metadata)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        weights = self.hf_to_vllm_mapper.apply(weights)
-        weights = ((name, data) for name, data in weights
-                   if not name.startswith("lm_head."))
-        self.model.load_weights(weights)
+        weights_list = list(weights)
+
+        has_model_prefix = any(
+            name.startswith("model.") for name, _ in weights_list)
+        if not has_model_prefix:
+            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
+
+        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
+        return loader.load_weights(weights_list, mapper=mapper)
 
     def _build_model(self,
                      vllm_config: VllmConfig,
@@ -471,27 +478,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                         self.classifier, self.bert.pooler)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        bert_weights = []
-        classifier_weights = []
-
-        for name, weight in weights:
-            if name.startswith("bert."):
-                bert_weights.append((name, weight))
-            else:
-                classifier_weights.append((name, weight))
-
         loader = AutoWeightsLoader(self)
-        loaded_params = loader.load_weights(bert_weights)
-
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in classifier_weights:
-            if name in params_dict:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-
+        loaded_params = loader.load_weights(weights)
         return loaded_params
 
     def pooler(

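The bert.py change above boils down to one pattern: checkpoint names that match the fused-QKV shard mapping are loaded inline with the shard-aware weight_loader, and everything that falls through the for/else is collected and handed to AutoWeightsLoader in a single pass. A stripped-down sketch of that split, using toy names rather than the real vLLM classes:

# Toy illustration of the split performed in the refactored load_weights:
# fused-QKV shards are handled inline, everything else is deferred to a
# generic loader (AutoWeightsLoader in the real code).
stacked_params_mapping = [
    # (fused param name, checkpoint shard name, shard id)
    ("qkv_proj", "query", "q"),
    ("qkv_proj", "key", "k"),
    ("qkv_proj", "value", "v"),
]

def split_weights(weights):
    stacked, other = [], []
    for name, tensor in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            # e.g. "encoder.layer.0.attention.self.query.weight"
            #   -> "encoder.layer.0.attention.self.qkv_proj.weight"
            stacked.append((name.replace(weight_name, param_name), shard_id, tensor))
            break
        else:
            # Runs only if no shard name matched: defer to the generic loader.
            other.append((name, tensor))
    return stacked, other

weights = [("encoder.layer.0.attention.self.query.weight", "W_q"),
           ("embeddings.word_embeddings.weight", "W_emb")]
stacked, deferred = split_weights(weights)
# stacked  -> [("encoder.layer.0.attention.self.qkv_proj.weight", "q", "W_q")]
# deferred -> [("embeddings.word_embeddings.weight", "W_emb")]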
vllm/model_executor/models/roberta.py

Lines changed: 18 additions & 25 deletions
@@ -13,9 +13,9 @@
 from vllm.model_executor.layers.pooler import ClassifierPooler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
-from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
+from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
+                                              maybe_prefix)
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
@@ -136,16 +136,20 @@ def _build_model(self,
                          embedding_class=RobertaEmbedding)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        weights = self.hf_to_vllm_mapper.apply(weights)
-        # Separate weights in "roberta"-prefixed and all else (not in memory).
-        # For use with models like FacebookAI/roberta-base.
-        bert_weights, task_weights = roberta_task_weights_filter(weights)
-        loaded = self.model.load_weights(bert_weights)
-        if not len(loaded):
-            # Fix for models like `sentence-transformers/stsb-roberta-base-v2`
-            # which use the same architecture, but have no "roberta" prefix.
-            loaded = self.model.load_weights(task_weights)
-        assert len(loaded), "Unable to load RobertaEmbeddingModel"
+        weights_list = list(weights)
+        has_roberta_prefix = any(
+            name.startswith("roberta.") for name, _ in weights_list)
+        if has_roberta_prefix:
+            # For models with the `roberta.` prefix e.g.
+            # `FacebookAI/roberta-base`
+            mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."})
+        else:
+            # For models without the `roberta.` prefix e.g.
+            # `sentence-transformers/stsb-roberta-base-v2`
+            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
+
+        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
+        return loader.load_weights(weights_list, mapper=mapper)
 
 
 class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
@@ -187,19 +191,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                         self.classifier)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        bert_weights, task_weights = roberta_task_weights_filter(weights)
-        bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
-
-        self.roberta.load_weights(bert_weights)
-
-        params_dict = dict(self.named_parameters())
-
-        for name, loaded_weight in task_weights:
-            if name.startswith("classifier"):
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)
 
     def pooler(
         self,

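Both Roberta load_weights paths above rely on the same idea: detect whether the checkpoint carries a roberta. prefix and pick a WeightsMapper that rewrites every name onto the wrapper's model. submodule, so AutoWeightsLoader always sees one consistent namespace. A toy approximation of that prefix rewrite (plain Python, not the actual WeightsMapper implementation from vllm.model_executor.models.utils):

def rewrite_prefix(names, orig_to_new_prefix):
    # Loose approximation of the renaming configured via
    # WeightsMapper(orig_to_new_prefix=...) in the diff above.
    out = []
    for name in names:
        for old, new in orig_to_new_prefix.items():
            if name.startswith(old):
                out.append(new + name[len(old):])
                break
        else:
            out.append(name)  # leave unmatched names untouched
    return out

# FacebookAI/roberta-base style checkpoint (has the "roberta." prefix):
print(rewrite_prefix(["roberta.embeddings.word_embeddings.weight"],
                     {"roberta.": "model."}))
# -> ['model.embeddings.word_embeddings.weight']

# stsb-roberta-base-v2 style checkpoint (no prefix): map everything under "model."
print(rewrite_prefix(["embeddings.word_embeddings.weight"],
                     {"": "model."}))
# -> ['model.embeddings.word_embeddings.weight']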