
Commit a204b47

tywuAMD authored and Mu Huai committed
[BugFix][TritonMLA] Process weights after model loading for GGUF (vllm-project#14555)
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
1 parent 9d68d2d commit a204b47

File tree

1 file changed

+4
-1
lines changed


vllm/model_executor/model_loader/loader.py

Lines changed: 4 additions & 1 deletion

@@ -1330,11 +1330,14 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                 local_model_path, gguf_weights_map):
             model_config.hf_config.update({"tie_word_embeddings": True})

+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
                 model.load_weights(
                     self._get_weights_iterator(local_model_path, gguf_weights_map))
+
+            _process_weights_after_loading(model, model_config, target_device)
         return model
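The shape of the fix is easier to see as a pattern: pick the target device once, initialize the model and load its weights under that device, then run a post-load processing hook so quantized (here, GGUF) layers can finalize their weights before inference — the hook call was previously missing from this loader. The sketch below is a hypothetical toy illustration, not vLLM's actual code; `ToyModel`, `process_weights_after_loading`, and `load_model` are stand-ins for vLLM's `_initialize_model`, `_process_weights_after_loading`, and the GGUF loader's `load_model`.

```python
# Toy sketch of the load-then-process pattern restored by this commit.
# All names here are illustrative stand-ins for vLLM internals.

class ToyModel:
    def __init__(self):
        self.weights = {}
        self.processed = False

    def load_weights(self, weights_iter):
        # Materialize weights from an iterator of (name, value) pairs,
        # analogous to consuming a GGUF weights iterator.
        for name, value in weights_iter:
            self.weights[name] = value


def process_weights_after_loading(model):
    # Stand-in for _process_weights_after_loading: in vLLM this lets each
    # quantization method repack or finalize its parameters post-load.
    model.processed = True


def load_model(weights_iter):
    model = ToyModel()
    model.load_weights(weights_iter)
    # The bug fixed by the commit: for GGUF checkpoints this step was
    # skipped, leaving quantized weights in their raw on-disk layout.
    process_weights_after_loading(model)
    return model


model = load_model(iter([("layer.0.weight", [1.0, 2.0])]))
print(model.processed)  # True
```

Calling the post-load hook after the weight iterator is exhausted (rather than during initialization) matters because quantization backends can only repack weights once every shard has been materialized.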
