12 changes: 8 additions & 4 deletions vllm/model_executor/layers/linear.py
@@ -431,14 +431,18 @@ def weight_loader(self,
block_n)
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n)
elif isinstance(param, PerTensorScaleParameter):
shard_offset = loaded_shard_id
shard_size = 1
Comment on lines -434 to -436
Member
Why remove this case?

Contributor
PerTensorScaleParameter is a subclass of BasevLLMParameter, and PerTensorScaleParameter.load_merged_column_weight already does the right weight loading, so I consider this a code-quality improvement.

TL;DR: Previously we set shard_offset and shard_size to values that don't match what their names suggest, just to make the subsequent assignment (param.data[shard_offset:shard_offset + shard_size] = loaded_weight) work. But PerTensorScaleParameter.load_merged_column_weight, which is just BasevLLMParameter._assert_and_load, does exactly the same thing.
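
To illustrate the equivalence (a minimal sketch with made-up class names and simplified shapes, not the real vLLM implementation): the old code faked shard_offset = loaded_shard_id and shard_size = 1 so that the slice assignment would land the per-shard scale in the right slot, which is what the parameter class already does itself.

```python
import torch

# Simplified stand-ins for the real vLLM parameter classes, for illustration only.
class BasevLLMParameterSketch:
    def __init__(self, data: torch.Tensor):
        self.data = data

    def _assert_and_load(self, loaded_weight: torch.Tensor, offset: int, size: int):
        # Copy one shard of the loaded weight into the matching slice of the param.
        assert self.data[offset:offset + size].shape == loaded_weight.shape
        self.data[offset:offset + size] = loaded_weight


class PerTensorScaleParameterSketch(BasevLLMParameterSketch):
    def load_merged_column_weight(self, loaded_weight, shard_id, **kwargs):
        # One scalar scale per merged shard: slot `shard_id` receives the value.
        self._assert_and_load(loaded_weight, offset=shard_id, size=1)


# Old path: fake shard_offset/shard_size so the slice assignment does the job.
param = PerTensorScaleParameterSketch(torch.zeros(3))
loaded_scale = torch.tensor([0.5])
shard_offset, shard_size = 2, 1          # shard_offset was really the shard id
param.data[shard_offset:shard_offset + shard_size] = loaded_scale

# New path: let the parameter class do it.
param2 = PerTensorScaleParameterSketch(torch.zeros(3))
param2.load_merged_column_weight(loaded_scale, shard_id=2)

assert torch.equal(param.data, param2.data)
```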

else:
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]

param.data[shard_offset:shard_offset + shard_size] = loaded_weight
if isinstance(param, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=0)
Comment on lines -441 to +443
Contributor
While I do think calling the kwarg tp_rank is misleading, it does the trick, and I can't come up with a better name.
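
A rough sketch of why tp_rank=0 does the trick (a simplified stand-in for load_merged_column_weight, not the actual function): with tp_rank=0 the source narrow starts at offset 0, so when loaded_weight has already been sliced down to the shard, the copy collapses to the old direct assignment.

```python
import torch

def load_merged_column_weight_sketch(param_data: torch.Tensor,
                                     loaded_weight: torch.Tensor,
                                     shard_offset: int,
                                     shard_size: int,
                                     output_dim: int = 0,
                                     tp_rank: int = 0):
    # Simplified version of the copy performed by load_merged_column_weight.
    dst = param_data.narrow(output_dim, shard_offset, shard_size)
    # With tp_rank=0 the source narrow starts at 0; if loaded_weight is already
    # exactly this shard, the narrow is effectively a no-op.
    src = loaded_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
    assert dst.shape == src.shape
    dst.copy_(src)

param = torch.zeros(8)
shard = torch.ones(4)                      # checkpoint tensor already shard-sized
load_merged_column_weight_sketch(param, shard, shard_offset=4, shard_size=4, tp_rank=0)
assert torch.equal(param[4:], shard)       # same result as param[4:8] = shard
```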

else:
param.data[shard_offset:shard_offset + shard_size] = loaded_weight


@CustomOp.register("column_parallel_linear")
4 changes: 2 additions & 2 deletions vllm/model_executor/parameter.py
@@ -122,7 +122,7 @@ def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):

param_data = self.data

tp_rank = get_tensor_model_parallel_rank()
tp_rank = kwargs.get("tp_rank", get_tensor_model_parallel_rank())
param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim,
@@ -456,4 +456,4 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
shard_offset=shard_offset,
bitblas_tile_size=bitblas_tile_size)

return shard_size, shard_offset
return shard_size, shard_offset
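
To round out the parameter.py change, a hedged sketch (simplified names, not the real function) of how the kwargs.get default keeps existing call sites on the current TP rank while letting the new linear.py caller pin it to 0:

```python
import torch

def narrow_source_sketch(loaded_weight: torch.Tensor,
                         shard_size: int,
                         current_tp_rank: int,
                         **kwargs) -> torch.Tensor:
    # Mirrors the changed line: an explicit tp_rank kwarg overrides the rank
    # read from the distributed state (simplified here to a plain argument).
    tp_rank = kwargs.get("tp_rank", current_tp_rank)
    return loaded_weight.narrow(0, tp_rank * shard_size, shard_size)

full = torch.arange(8.0)                   # full checkpoint tensor, 2 TP ranks x 4
# Existing callers: no kwarg, use the actual rank (pretend we are rank 1).
assert torch.equal(narrow_source_sketch(full, 4, current_tp_rank=1), full[4:])
# New caller in linear.py: force rank 0 so the already-sliced weight is used as-is.
shard = torch.ones(4)
assert torch.equal(narrow_source_sketch(shard, 4, current_tp_rank=1, tp_rank=0), shard)
```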