Skip to content

Commit

Permalink
simon changes
Browse files Browse the repository at this point in the history
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
  • Loading branch information
LucasWilkinson committed Jan 31, 2025
1 parent a57cd3d commit 548ec44
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
2 changes: 1 addition & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def is_deepseek_mla(self) -> bool:
# TODO add deepseek_v3
return hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
in ('deepseek_v2'))
in ('deepseek_v2', 'deepseek_v3'))

def get_head_size(self) -> int:
# TODO remove hard code
Expand Down
15 changes: 15 additions & 0 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,11 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
with device_loading_context(
module, torch.device(device_config.device)):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after loading;
# currently only used by MLA.
module.process_weights_after_loading()
return model.eval()


Expand Down Expand Up @@ -633,6 +638,11 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after loading;
# currently only used by MLA.
module.process_weights_after_loading()
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
local_model_path,
Expand Down Expand Up @@ -1369,6 +1379,11 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after loading;
# currently only used by MLA.
module.process_weights_after_loading()
return model.eval()


Expand Down

0 comments on commit 548ec44

Please sign in to comment.