
Commit 91373a0

Authored by Harry Mellor (hmellor)
Fix head_dim not existing in all model configs (Transformers backend) (#14141)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent 848a643 commit 91373a0
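
Background: some Hugging Face model configs do not define a head_dim attribute, so reading config.head_dim directly (as the old code did) raises an AttributeError for those models. The diff below switches to vLLM's ModelConfig helpers, which can derive the value instead. A minimal sketch of the idea, assuming the conventional fallback of hidden_size // num_attention_heads (illustrative only, not vLLM's actual get_head_size implementation):

# Illustrative sketch, not code from vLLM: shows why config.head_dim is
# fragile and what a helper can fall back to when the attribute is missing.
from transformers import PretrainedConfig


def head_size_from_config(config: PretrainedConfig) -> int:
    # Some configs never set head_dim, so use getattr with a default
    # instead of plain attribute access to avoid an AttributeError.
    head_dim = getattr(config, "head_dim", None)
    if head_dim is not None:
        return head_dim
    # Conventional fallback: split the hidden size evenly across heads.
    return config.hidden_size // config.num_attention_heads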

File tree

1 file changed: +15 -12 lines changed

vllm/model_executor/models/transformers.py

Lines changed: 15 additions & 12 deletions
@@ -25,7 +25,6 @@
 from vllm.attention import Attention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.distributed.utils import divide
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
@@ -128,10 +127,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
+        parallel_config = vllm_config.parallel_config
 
         self.config = config
-        self.vocab_size = config.vocab_size
-        self.unpadded_vocab_size = config.vocab_size
+        self.vocab_size = model_config.get_vocab_size()
+        self.unpadded_vocab_size = model_config.get_vocab_size()
 
         self.model: PreTrainedModel = AutoModel.from_config(
             self.config,
@@ -145,15 +146,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.apply_base_model_tp_plan(self.model)
 
         # Attention modifications (assumes 1 attention op per hidden layer)
-        tp_size = get_tensor_model_parallel_world_size()
+        num_heads = model_config.get_num_attention_heads(parallel_config)
+        head_size = model_config.get_head_size()
+        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.attention_instances = [
             Attention(
-                num_heads=divide(config.num_attention_heads, tp_size),
-                head_size=config.head_dim,
+                num_heads=num_heads,
+                head_size=head_size,
                 # NOTE: We use Llama scale as default, if it's set by
                 # Transformers, it's updated in vllm_flash_attention_forward
-                scale=config.head_dim**-0.5,
-                num_kv_heads=divide(config.num_key_value_heads, tp_size),
+                scale=head_size**-0.5,
+                num_kv_heads=num_kv_heads,
                 cache_config=cache_config,
                 quant_config=self.quant_config,
                 prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
@@ -163,7 +166,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.replace_vocab_embed_class(self.model)
 
         # ForCausalLM modifications
-        self.lm_head = ParallelLMHead(config.vocab_size,
+        self.lm_head = ParallelLMHead(self.vocab_size,
                                       config.hidden_size,
                                       quant_config=self.quant_config,
                                       prefix=maybe_prefix(prefix, "lm_head"))
@@ -172,7 +175,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
+                                                self.vocab_size, logit_scale)
         self.sampler = get_sampler()
 
     def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
@@ -203,12 +206,12 @@ def replace_vocab_embed_class(self, module: nn.Module):
         new_module = VocabParallelEmbedding(
             self.vocab_size,
             self.config.hidden_size,
-            org_num_embeddings=self.config.vocab_size,
+            org_num_embeddings=self.vocab_size,
             quant_config=None,
         )
         log_replacement("input embedding", self.model.get_input_embeddings(),
                         new_module)
-        self.model.set_input_embeddings(new_module)
+        module.set_input_embeddings(new_module)
 
     def forward(
         self,