Commit 06ad77d

[ Misc ] Remove fp8_shard_indexer from Col/Row Parallel Linear (Simplify Weight Loading) (vllm-project#5928)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>
2 people authored and prashantgupta24 committed Jul 1, 2024
1 parent e567962 commit 06ad77d
Showing 1 changed file with 8 additions and 20 deletions.
28 changes: 8 additions & 20 deletions vllm/model_executor/layers/linear.py
```diff
@@ -269,10 +269,6 @@ def __init__(self,
         self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
-
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
         param_data = param.data
@@ -281,11 +277,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
-                                                                 loaded_weight,
-                                                                 shard_id=0)
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
 
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
```
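After this commit, ColumnParallelLinear's weight_loader keeps only the plain tensor-parallel path: it narrows the full checkpoint tensor along output_dim to this rank's slice and copies it into the local shard (the RowParallelLinear hunks below are symmetric, narrowing along input_dim instead). Here is a minimal standalone sketch of that path, with a hypothetical helper name, made-up shapes, and a 2-rank setup; it is not vLLM's actual class code:

```python
import torch

def load_column_parallel_shard(param_data: torch.Tensor,
                               loaded_weight: torch.Tensor,
                               tp_rank: int,
                               output_dim: int = 0) -> None:
    """Copy this rank's slice of a full checkpoint tensor into a local shard."""
    # The per-rank shard size is the local parameter's extent along the
    # partitioned dimension.
    shard_size = param_data.shape[output_dim]
    start_idx = tp_rank * shard_size
    # narrow() returns a view of this rank's slice; the full tensor is
    # never copied, only the slice is.
    loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)

# Two-rank example: an 8x4 checkpoint weight split into 4x4 shards.
full_weight = torch.arange(32, dtype=torch.float32).reshape(8, 4)
shard = torch.empty(4, 4)  # rank-local parameter (tp_size=2)
load_column_parallel_shard(shard, full_weight, tp_rank=1)
assert torch.equal(shard, full_weight[4:8])
```

Per the commit title, the fp8_scales_shard_indexer branch is dropped so both loaders share this single sharding path plus the scalar-scale fixup shown after the next diff.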
```diff
@@ -751,10 +747,6 @@ def __init__(self,
         self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
-
         tp_rank = get_tensor_model_parallel_rank()
         input_dim = getattr(param, "input_dim", None)
         param_data = param.data
@@ -764,13 +756,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
 
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
-                                                                 loaded_weight,
-                                                                 shard_id=0)
-
-        if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
         assert param_data.shape == loaded_weight.shape
```
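The other behavioral piece of the change is that the 0-dim special case now runs unconditionally in both loaders. A minimal sketch of why the reshape is needed, assuming a quantization scale serialized as a scalar (0-dim) tensor, as the diff comment says AutoFP8 produces; the scale value and the shape-(1,) parameter are made up for illustration:

```python
import torch

# Registered scale parameter in the layer: shape (1,).
param_data = torch.ones(1)
# Scale as serialized on disk: a 0-dim (scalar) tensor.
loaded_weight = torch.tensor(0.025)

# A 0-dim tensor has an empty shape, so the shape assert below would
# fail without this fixup: torch.Size([]) != torch.Size([1]).
if len(loaded_weight.shape) == 0:
    loaded_weight = loaded_weight.reshape(1)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
```

Because the check keys only on the loaded tensor's shape, it no longer needs to know whether the parameter is an FP8 scale, which is what lets the per-parameter fp8_scales_shard_indexer attribute be removed entirely.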
