[Misc] Add channel-wise quantization support for w8a8 dynamic per token activation quantization #5542
Changes from all commits: 666de99, f387746, fd43792, d804d98, 80de429, bc50ba7, 5d59f7a, 3319697
First changed file: the linear-layer weight loaders (MergedColumnParallelLinear and QKVParallelLinear).

@@ -468,13 +468,6 @@ def weight_loader(self,
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

-        if fp8_scales_shard_indexer is None:
-            if len(param_data.shape) == 0:
-                param_data = param_data.reshape(1)
-
-            if len(loaded_weight.shape) == 0:
-                loaded_weight = loaded_weight.reshape(1)
-
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

Inline review thread on the removed block:
Reviewer: Can you just explain why this was about to be removed?
Author: The models we're getting now have shapes defined for the scales. We were casting the shapes initially because, before, x.shape returned empty.
Reviewer: Makes sense. This is much cleaner.
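To make the thread above concrete, here is a minimal standalone sketch of the 0-dim scale case the removed guard handled (tensor names and values are illustrative, not vLLM code). Newer compressed-tensors checkpoints store scales with explicit shapes, which is why the reshape is no longer needed:

    import torch

    # Older checkpoints serialized a per-tensor scale as a 0-dim tensor
    # (empty shape), so copying it into a (1,)-shaped parameter required
    # an explicit reshape first.
    loaded_weight = torch.tensor(0.02)   # shape: torch.Size([])
    param_data = torch.empty(1)          # shape: torch.Size([1])

    if len(loaded_weight.shape) == 0:
        loaded_weight = loaded_weight.reshape(1)

    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)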
The same guards are removed from QKVParallelLinear.weight_loader:

@@ -686,12 +679,6 @@ def weight_loader(self,
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

-        if len(param_data.shape) == 0:
-            param_data = param_data.reshape(1)
-
-        if len(loaded_weight.shape) == 0:
-            loaded_weight = loaded_weight.reshape(1)
-
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
Second changed file: the CompressedTensorsW8A8DynamicToken scheme.

@@ -6,13 +6,18 @@
 from vllm import _custom_ops as custom_ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationStrategy)
 from vllm.model_executor.utils import set_weight_attrs

 __all__ = ["CompressedTensorsW8A8DynamicToken"]


 class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):

+    def __init__(self, strategy: str):
+        self.strategy = strategy
+
     def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
         if isinstance(shard_id, int):
             return shard_id

Inline review thread on the scheme changes:
Reviewer: Do we need to do the same thing for static per tensor?
Author: Do what?
Reviewer: Do we support channel-wise for static per-tensor already? These changes only impact dynamic per token, it seems, so I was just wondering.
Author: Yup - just dynamic per token so far. It's an easy addition to the static per-tensor scheme - I can add it as part of this PR. It will be a few lines of code to update.
Reviewer: Let's do it as a follow-up.
In create_weights, the weight scale is now sized per strategy:

@@ -45,11 +50,17 @@ def create_weights(self, layer: torch.nn.Module,
         # CompressedTensorsW8A8StaticTensor::create_weights for further
         # information.
         is_tensor_partitioned = len(output_partition_sizes) != 1
-        weight_scale_dim = sum(
-            output_partition_sizes) if is_tensor_partitioned else 1
+        # when doing channel-wise quantization, number of scales
+        # is equal to output_dim
+        weight_scale_dim = sum(output_partition_sizes) if (
+            is_tensor_partitioned
+            or self.strategy == QuantizationStrategy.CHANNEL) else 1

+        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            shape = (weight_scale_dim, 1)
+
-        weight_scale = Parameter(torch.empty(weight_scale_dim,
-                                             dtype=torch.float32),
+        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
                                  requires_grad=False)

         weight = Parameter(torch.empty(sum(output_partition_sizes),
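To make the shape logic concrete, a standalone sketch that mirrors the branch above (the strategy constants are stand-ins for QuantizationStrategy members; partition sizes are made up):

    from typing import List, Tuple, Union

    CHANNEL = "channel"  # stand-in for QuantizationStrategy.CHANNEL
    TENSOR = "tensor"    # stand-in for QuantizationStrategy.TENSOR

    def scale_shape(output_partition_sizes: List[int],
                    strategy: str) -> Union[Tuple[int], Tuple[int, int]]:
        # Mirrors create_weights: partitioned (merged) layers and channel-wise
        # quantization both need one scale per output channel of the merged weight.
        is_tensor_partitioned = len(output_partition_sizes) != 1
        weight_scale_dim = sum(output_partition_sizes) if (
            is_tensor_partitioned or strategy == CHANNEL) else 1
        if strategy == CHANNEL:
            return (weight_scale_dim, 1)
        return (weight_scale_dim, )

    print(scale_shape([5504, 5504], CHANNEL))  # (11008, 1): one scale per output row
    print(scale_shape([5504, 5504], TENSOR))   # (11008,):   one scale per logical shard
    print(scale_shape([4096], TENSOR))         # (1,):       single per-tensor weight scale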
The weight-scale loading attributes are then set per strategy:

@@ -67,12 +78,20 @@ def create_weights(self, layer: torch.nn.Module,
         })

         layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(
-            weight_scale, {
-                "weight_loader": weight_loader,
-                "shard_splitter": self.scales_shard_splitter,
-                "logical_widths": output_partition_sizes
-            })
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+
+        # Don't need a shard_splitter for channel-wise quantization
+        # Use the default loading method
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            set_weight_attrs(weight_scale, {
+                "output_dim": 0,
+            })
+        else:
+            set_weight_attrs(
+                weight_scale, {
+                    "logical_widths": output_partition_sizes,
+                    "shard_splitter": self.scales_shard_splitter,
+                })

     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
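Why no shard_splitter is needed for channel-wise scales: with output_dim registered, a scale shaped (output_dim, 1) can be sliced along its rows exactly like the weight itself. A simplified illustration of that idea (made-up sizes, not vLLM's actual loader):

    import torch

    full_scale = torch.rand(11008, 1)   # checkpoint scale: one row per output channel
    tp_size, tp_rank = 2, 0             # made-up tensor-parallel configuration
    shard_size = full_scale.shape[0] // tp_size

    # Narrow along dim 0 (the registered output_dim) to get this rank's slice,
    # the same way the weight rows themselves are sharded.
    local_scale = full_scale.narrow(0, tp_rank * shard_size, shard_size)
    print(local_scale.shape)            # torch.Size([5504, 1])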
Inline review thread on the updated model references:
Reviewer: Why do these models need to change? Do the existing ones stop working?
Author: There were changes on compressed-tensors' end in terms of explicit shapes for the scales (i.e. x.shape has a len > 1 now, and before it did not). Sara actually updated the old models last night, so these can be reverted back.
Reviewer: Sounds good.
Reviewer: I would prefer to switch back to make the PR cleaner.
Reviewer: Never mind - we are green, so let's go.
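Putting the PR together conceptually, here is a rough sketch of w8a8 with channel-wise weight scales and dynamic per-token activation scales (plain PyTorch math with made-up sizes; the actual scheme dispatches to vLLM custom ops):

    import torch

    torch.manual_seed(0)
    x = torch.randn(8, 512)        # activations: [num_tokens, hidden]
    w = torch.randn(1024, 512)     # weight: [output_dim, hidden]

    # Channel-wise weight scales: one per output channel -> [1024, 1]
    w_scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    w_q = torch.clamp((w / w_scale).round(), -128, 127).to(torch.int8)

    # Dynamic per-token activation scales, computed at runtime -> [8, 1]
    x_scale = x.abs().amax(dim=1, keepdim=True) / 127.0
    x_q = torch.clamp((x / x_scale).round(), -128, 127).to(torch.int8)

    # Real kernels run an int8 GEMM with int32 accumulation; for this sketch a
    # float32 matmul of the int8 values gives the same integer result, which is
    # then dequantized with both scale sets.
    y = (x_q.float() @ w_q.float().t()) * x_scale * w_scale.t()
    err = (y - x @ w.t()).abs().max()
    print(f"max abs error vs. fp32 matmul: {err:.3f}")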