diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index a18ea5601ba94..e48250670b7cd 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -271,10 +271,6 @@ def __init__(self,
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
-
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
         param_data = param.data
@@ -283,11 +279,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
-                                                                 loaded_weight,
-                                                                 shard_id=0)
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
 
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
@@ -781,10 +777,6 @@ def __init__(self,
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
-
         tp_rank = get_tensor_model_parallel_rank()
         input_dim = getattr(param, "input_dim", None)
         param_data = param.data
@@ -794,13 +786,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
 
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
-                                                                 loaded_weight,
-                                                                 shard_id=0)
-
-        if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
         assert param_data.shape == loaded_weight.shape
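
For context on the special case both loaders now share: a minimal standalone sketch of why the `reshape(1)` is needed. Checkpoints such as those produced by AutoFP8 often serialize a per-tensor scale as a 0-dim (scalar) tensor, while the registered scale parameter has shape `(1,)`, so the loader's shape assertion would fail without the reshape. The tensor values below are illustrative, not taken from a real checkpoint.

```python
import torch

# A scale parameter as it might be registered on the layer: shape (1,).
param_data = torch.empty(1, dtype=torch.float32)

# A per-tensor scale as it is often serialized on disk (e.g. by
# AutoFP8): a 0-dim scalar tensor whose shape is torch.Size([]).
loaded_weight = torch.tensor(0.0125, dtype=torch.float32)
assert len(loaded_weight.shape) == 0

# Without the reshape, the loader's shape check would fail:
# torch.Size([1]) != torch.Size([]).
if len(loaded_weight.shape) == 0:
    loaded_weight = loaded_weight.reshape(1)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
```

Since the check only fires for 0-dim tensors, ordinary sharded weights, which take the `narrow` path above it, pass through the loaders unchanged.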