diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index fa8a261db7d7..fd88eac55cb5 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -262,7 +262,7 @@ def __init__(
         self.tp_size = (get_tensor_model_parallel_world_size()
                         if not disable_tp else 1)
 
-    def __post_init__(self):
+    def update_param_tp_status(self):
         for param in self.parameters():
             if isinstance(param, BasevLLMParameter):
                 param.tp_rank = self.tp_rank
@@ -459,6 +459,7 @@ def __init__(
             })
         else:
             self.register_parameter("bias", None)
+        self.update_param_tp_status()
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
@@ -1250,6 +1251,7 @@ def __init__(
             })
         else:
             self.register_parameter("bias", None)
+        self.update_param_tp_status()
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         input_dim = getattr(param, "input_dim", None)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index de22cceb45d1..65e0b7062153 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -270,7 +270,8 @@ def create_weights(
         layer.weight_block_size = None
 
         if self.block_quant:
-            tp_size = get_tensor_model_parallel_world_size()
+            tp_size = getattr(layer, "tp_size",
+                              get_tensor_model_parallel_world_size())
             assert self.quant_config.weight_block_size is not None
             layer.weight_block_size = self.quant_config.weight_block_size
             block_n, block_k = (
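
The linear.py hunks rename `__post_init__` to `update_param_tp_status` and call it explicitly at the end of `__init__`. The key point is that `__post_init__` is a hook honored only by `dataclasses`; `torch.nn.Module` never invokes it, so the tp metadata was silently never copied onto the parameters. Below is a minimal, self-contained sketch of that failure mode and the fix. `FakevLLMParameter` and `TinyLinear` are hypothetical stand-ins for `BasevLLMParameter` and the linear layer base class, and the sketch assumes the real hook also copies `tp_size` alongside `tp_rank` (only the `tp_rank` line is visible in this hunk's context).

import torch
import torch.nn as nn


class FakevLLMParameter(nn.Parameter):
    """Stand-in for BasevLLMParameter: a Parameter carrying tp metadata."""


class TinyLinear(nn.Module):
    """Stand-in for the linear base class with the patched behavior."""

    def __init__(self, tp_rank: int = 0, tp_size: int = 2):
        super().__init__()
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.weight = FakevLLMParameter(torch.empty(8, 8))
        # nn.Module never calls __post_init__ (only dataclasses do), so under
        # the old name this propagation never ran. The patch renames the hook
        # and invokes it explicitly once all parameters are registered.
        self.update_param_tp_status()

    def update_param_tp_status(self):
        for param in self.parameters():
            if isinstance(param, FakevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size


layer = TinyLinear()
assert layer.weight.tp_rank == 0
assert layer.weight.tp_size == 2

The fp8.py hunk is the consumer side of the same fix: `getattr(layer, "tp_size", get_tensor_model_parallel_world_size())` lets block quantization prefer the `tp_size` the layer now carries (pinned to 1 when `disable_tp` is set for that layer) and fall back to the global tensor-parallel world size only when the attribute is absent.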