Skip to content

Commit f793ed3

Browse files
candyzone authored and debroy-rh committed
[Perf] Optimize memory peak during EAGLE model loading. (vllm-project#24585)
Signed-off-by: Chen Ding <candy.dc@alibaba-inc.com>
1 parent 0a32b0b commit f793ed3

File tree

3 files changed

+26
-26
lines changed

3 files changed

+26
-26
lines changed

vllm/model_executor/models/deepseek_eagle.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,15 @@ def compute_logits(
229229
return logits
230230

231231
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
232+
233+
def transform(inputs):
234+
name, loaded_weight = inputs
235+
if "lm_head" not in name:
236+
name = "model." + name
237+
return name, loaded_weight
238+
232239
loader = AutoWeightsLoader(
233240
self,
234241
skip_prefixes=None,
235242
)
236-
237-
model_weights = {}
238-
for name, loaded_weight in weights:
239-
if "lm_head" not in name:
240-
name = "model." + name
241-
model_weights[name] = loaded_weight
242-
loader.load_weights(model_weights.items())
243+
loader.load_weights(map(transform, weights))

vllm/model_executor/models/llama4_eagle.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -205,23 +205,21 @@ def forward(
205205

206206
def load_weights(self, weights: Iterable[tuple[str,
207207
torch.Tensor]]) -> None:
208+
209+
def transform(inputs):
210+
name, loaded_weight = inputs
211+
name, weight = self.permute_qk_weight_for_rotary(
212+
name, loaded_weight)
213+
if "lm_head" not in name:
214+
name = "model." + name
215+
return name, weight
216+
208217
loader = AutoWeightsLoader(
209218
self,
210219
# lm_head is tied with target model (Llama4ForCausalLM)
211220
skip_prefixes=(["lm_head."]),
212221
)
213-
214-
model_weights = {}
215-
weights = [
216-
self.permute_qk_weight_for_rotary(name, loaded_weight)
217-
for name, loaded_weight in weights
218-
]
219-
for name, loaded_weight in weights:
220-
if "lm_head" not in name:
221-
name = "model." + name
222-
model_weights[name] = loaded_weight
223-
224-
loader.load_weights(model_weights.items())
222+
loader.load_weights(map(transform, weights))
225223

226224
def get_input_embeddings(
227225
self,

vllm/model_executor/models/llama_eagle.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,15 @@ def forward(
158158
return self.model(input_ids, positions, hidden_states)
159159

160160
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
161+
162+
def transform(inputs):
163+
name, loaded_weight = inputs
164+
if "lm_head" not in name:
165+
name = "model." + name
166+
return name, loaded_weight
167+
161168
loader = AutoWeightsLoader(
162169
self,
163170
skip_prefixes=None,
164171
)
165-
166-
model_weights = {}
167-
for name, loaded_weight in weights:
168-
if "lm_head" not in name:
169-
name = "model." + name
170-
model_weights[name] = loaded_weight
171-
loader.load_weights(model_weights.items())
172+
loader.load_weights(map(transform, weights))

0 commit comments

Comments (0)