vllm/model_executor/models: 3 files changed, +26 -26 lines

@@ -229,14 +229,15 @@ def compute_logits(
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+
+        def transform(inputs):
+            name, loaded_weight = inputs
+            if "lm_head" not in name:
+                name = "model." + name
+            return name, loaded_weight
+
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,
         )
-
-        model_weights = {}
-        for name, loaded_weight in weights:
-            if "lm_head" not in name:
-                name = "model." + name
-            model_weights[name] = loaded_weight
-        loader.load_weights(model_weights.items())
+        loader.load_weights(map(transform, weights))

@@ -205,23 +205,21 @@ def forward(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> None:
+
+        def transform(inputs):
+            name, loaded_weight = inputs
+            name, weight = self.permute_qk_weight_for_rotary(
+                name, loaded_weight)
+            if "lm_head" not in name:
+                name = "model." + name
+            return name, weight
+
         loader = AutoWeightsLoader(
             self,
             # lm_head is tied with target model (Llama4ForCausalLM)
             skip_prefixes=(["lm_head."]),
         )
-
-        model_weights = {}
-        weights = [
-            self.permute_qk_weight_for_rotary(name, loaded_weight)
-            for name, loaded_weight in weights
-        ]
-        for name, loaded_weight in weights:
-            if "lm_head" not in name:
-                name = "model." + name
-            model_weights[name] = loaded_weight
-
-        loader.load_weights(model_weights.items())
+        loader.load_weights(map(transform, weights))
 
     def get_input_embeddings(
         self,

@@ -158,14 +158,15 @@ def forward(
         return self.model(input_ids, positions, hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+
+        def transform(inputs):
+            name, loaded_weight = inputs
+            if "lm_head" not in name:
+                name = "model." + name
+            return name, loaded_weight
+
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,
         )
-
-        model_weights = {}
-        for name, loaded_weight in weights:
-            if "lm_head" not in name:
-                name = "model." + name
-            model_weights[name] = loaded_weight
-        loader.load_weights(model_weights.items())
+        loader.load_weights(map(transform, weights))
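
In all three files the change is the same refactor: instead of materializing a model_weights dict and renaming keys in a loop, a local transform helper is mapped lazily over the incoming weight iterator and passed directly to AutoWeightsLoader.load_weights. Below is a minimal, runnable sketch of that pattern; StubLoader is a hypothetical stand-in for vLLM's AutoWeightsLoader, assumed only to accept an iterable of (name, tensor) pairs as the diffs show.

# Sketch of the lazy weight-renaming pattern used in this diff.
# StubLoader is a placeholder, not vLLM's real AutoWeightsLoader.
from collections.abc import Iterable

import torch


class StubLoader:
    """Stand-in loader; only the call shape (iterable of pairs) matters."""

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
        for name, tensor in weights:
            print(f"loading {name}: shape={tuple(tensor.shape)}")


def load_weights(weights: Iterable[tuple[str, torch.Tensor]]) -> None:
    def transform(inputs: tuple[str, torch.Tensor]) -> tuple[str, torch.Tensor]:
        name, loaded_weight = inputs
        # Prefix everything except lm_head, as in the diffs above.
        if "lm_head" not in name:
            name = "model." + name
        return name, loaded_weight

    loader = StubLoader()
    # map() keeps the weight stream lazy: no intermediate dict holding
    # every tensor is built before the loader consumes the pairs.
    loader.load_weights(map(transform, weights))


if __name__ == "__main__":
    dummy = [
        ("embed_tokens.weight", torch.zeros(4, 8)),
        ("lm_head.weight", torch.zeros(8, 4)),
    ]
    load_weights(dummy)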