From 06ad77dc8ad229b12ec0dd505198ccb8577f0d5c Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Date: Fri, 28 Jun 2024 13:49:57 -0400
Subject: [PATCH] [ Misc ] Remove `fp8_shard_indexer` from Col/Row Parallel
 Linear (Simplify Weight Loading) (#5928)

Co-authored-by: Robert Shaw
---
 vllm/model_executor/layers/linear.py | 28 ++++++++--------------------
 1 file changed, 8 insertions(+), 20 deletions(-)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 45f805547b414..fe7c2a295b70c 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -269,10 +269,6 @@ def __init__(self,
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
-
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
         param_data = param.data
@@ -281,11 +277,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
-                                                                 loaded_weight,
-                                                                 shard_id=0)
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
 
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
@@ -751,10 +747,6 @@ def __init__(self,
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
-
         tp_rank = get_tensor_model_parallel_rank()
         input_dim = getattr(param, "input_dim", None)
         param_data = param.data
@@ -764,13 +756,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
 
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
-                                                                 loaded_weight,
-                                                                 shard_id=0)
-
-        if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)
 
         assert param_data.shape == loaded_weight.shape
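
Note (illustrative, not part of the patch): the special case kept above exists
because per-tensor quantization scales are often serialized as 0-dim tensors
(e.g. by AutoFP8-style checkpoints), while the scale parameter registered on
the layer is 1-dim. A minimal sketch of the shape mismatch and the fix, using
made-up values rather than anything from the vLLM codebase:

    import torch

    # A per-tensor scale saved as a scalar (0-dim) tensor, as AutoFP8-style
    # checkpoints often store it.
    loaded_weight = torch.tensor(0.02)   # shape: torch.Size([])

    # The scale parameter registered on the layer is 1-dim.
    param_data = torch.empty(1)          # shape: torch.Size([1])

    # The special case in weight_loader: promote the scalar to shape (1,)
    # so the shape assertion and the copy below succeed.
    if len(loaded_weight.shape) == 0:
        loaded_weight = loaded_weight.reshape(1)

    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)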