Skip to content

Commit 65b80fe

Browse files
committed
[Model] Use AutoWeightsLoader for BERT and fix position_ids.
Signed-off-by: Jennifer He <islandhe@gmail.com>
1 parent c18b3b8 commit 65b80fe

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

vllm/model_executor/models/bert.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from vllm.sequence import IntermediateTensors, PoolerOutput
2828

2929
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
30-
from .utils import WeightsMapper, maybe_prefix
30+
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
3131

3232

3333
class BertEmbedding(nn.Module):
@@ -44,8 +44,9 @@ def __init__(self, config: BertConfig):
4444
config.type_vocab_size, config.hidden_size)
4545
self.LayerNorm = nn.LayerNorm(config.hidden_size,
4646
eps=config.layer_norm_eps)
47-
self.position_ids = nn.Parameter(
48-
torch.empty((1, config.max_position_embeddings)), )
47+
self.register_buffer(
48+
"position_ids",
49+
torch.arange(config.max_position_embeddings).expand((1, -1)))
4950

5051
self.position_embedding_type = config.position_embedding_type
5152
if self.position_embedding_type != "absolute":
@@ -470,26 +471,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
470471
self.classifier, self.bert.pooler)
471472

472473
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load checkpoint weights, routing them to the BERT backbone or the
    classifier head.

    Names starting with ``"bert."`` keep their prefix and are delegated to
    an ``AutoWeightsLoader`` rooted at ``self``, which maps them onto the
    ``self.bert`` submodule.  All remaining names are treated as head
    weights (e.g. ``classifier.*``) and copied directly into this module's
    matching parameters.

    Returns:
        The set of fully-qualified parameter names that were loaded.
    """
    bert_weights = []
    classifier_weights = []
    for name, weight in weights:
        if name.startswith("bert."):
            bert_weights.append((name, weight))
        else:
            classifier_weights.append((name, weight))

    loader = AutoWeightsLoader(self)
    loaded_params = loader.load_weights(bert_weights)

    # Head weights are loaded straight into this module's parameters;
    # a checkpoint entry with no matching parameter is silently skipped.
    params_dict = dict(self.named_parameters())
    for name, loaded_weight in classifier_weights:
        if name in params_dict:
            param = params_dict[name]
            # Respect a parameter-specific loader when one is attached
            # (vLLM convention); otherwise fall back to the default copy.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

    return loaded_params
493496

494497
def pooler(
495498
self,

0 commit comments

Comments
 (0)