diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index c0fcacd1e6ee..ebc4f67f1ad7 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -431,14 +431,18 @@ def weight_loader(self,
                               block_n)
             shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                           block_n)
-        elif isinstance(param, PerTensorScaleParameter):
-            shard_offset = loaded_shard_id
-            shard_size = 1
         else:
             shard_offset = sum(self.output_sizes[:loaded_shard_id])
             shard_size = self.output_sizes[loaded_shard_id]
 
-        param.data[shard_offset:shard_offset + shard_size] = loaded_weight
+        if isinstance(param, BasevLLMParameter):
+            param.load_merged_column_weight(loaded_weight=loaded_weight,
+                                            shard_id=loaded_shard_id,
+                                            shard_offset=shard_offset,
+                                            shard_size=shard_size,
+                                            tp_rank=0)
+        else:
+            param.data[shard_offset:shard_offset + shard_size] = loaded_weight
 
 
 @CustomOp.register("column_parallel_linear")
diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py
index 750ee7850268..5466ed50ca42 100644
--- a/vllm/model_executor/parameter.py
+++ b/vllm/model_executor/parameter.py
@@ -122,7 +122,7 @@ def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
 
         param_data = self.data
 
-        tp_rank = get_tensor_model_parallel_rank()
+        tp_rank = kwargs.get("tp_rank", get_tensor_model_parallel_rank())
         param_data = param_data.narrow(self.output_dim, shard_offset,
                                        shard_size)
         loaded_weight = loaded_weight.narrow(self.output_dim,
@@ -456,4 +456,4 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
             shard_offset=shard_offset,
             bitblas_tile_size=bitblas_tile_size)
 
-    return shard_size, shard_offset
\ No newline at end of file
+    return shard_size, shard_offset
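
The substance of the patch is the `tp_rank` override in `load_merged_column_weight`: `kwargs.get("tp_rank", get_tensor_model_parallel_rank())` lets a caller pin the shard index explicitly (the new branch in `linear.py` passes `tp_rank=0`), while existing callers that omit the kwarg keep the ambient tensor-parallel rank. Below is a minimal standalone sketch of that fallback pattern; `load_shard`, `full_weight`, and `local_param` are illustrative names, not vLLM APIs.

import torch

def load_shard(param_data: torch.Tensor, loaded_weight: torch.Tensor,
               output_dim: int, shard_size: int, default_rank: int, **kwargs):
    # Mirrors the patched load_merged_column_weight: an explicit tp_rank
    # kwarg wins; otherwise fall back to the process's TP rank.
    tp_rank = kwargs.get("tp_rank", default_rank)
    # Select this rank's slice of the full checkpoint tensor.
    shard = loaded_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
    param_data.copy_(shard)
    return tp_rank

# Full (unsharded) weight as it appears in the checkpoint: 4 shards of 2 rows.
full_weight = torch.arange(8.0).reshape(8, 1)
local_param = torch.empty(2, 1)

# Default behaviour: use the ambient TP rank (pretend this process is rank 2).
load_shard(local_param, full_weight, output_dim=0, shard_size=2, default_rank=2)
assert local_param.flatten().tolist() == [4.0, 5.0]

# With the override, a caller such as the weight_loader in linear.py can force
# rank 0 so the slice starts at offset 0 regardless of the ambient rank.
load_shard(local_param, full_weight, output_dim=0, shard_size=2,
           default_rank=2, tp_rank=0)
assert local_param.flatten().tolist() == [0.0, 1.0]

Passing `tp_rank=0` from `linear.py` appears intended to skip the per-rank narrowing in that code path, where the shard offset and size into the full weight have already been computed by the caller.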