[Misc] Add channel-wise quantization support for w8a8 dynamic per token activation quantization #5542
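For context (this illustration is not from the PR), "w8a8 with dynamic per-token activation quantization" means the weights are stored as int8 with precomputed scales, while each activation row (token) gets its own int8 scale computed on the fly. A minimal standalone sketch of the activation side, with hypothetical sizes:

# Standalone illustration of dynamic per-token int8 activation quantization (not vLLM code).
import torch

def quantize_per_token(x: torch.Tensor):
    # One scale per token (row): symmetric int8 quantization computed at runtime.
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.round(x / scales).clamp(-128, 127).to(torch.int8)
    return x_q, scales

x = torch.randn(4, 16)            # 4 tokens, hidden size 16
x_q, x_scales = quantize_per_token(x)
print(x_q.dtype, x_scales.shape)  # torch.int8, torch.Size([4, 1])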
@@ -6,13 +6,18 @@
from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    QuantizationStrategy)
from vllm.model_executor.utils import set_weight_attrs

__all__ = ["CompressedTensorsW8A8DynamicToken"]


class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):

    def __init__(self, strategy: str):
        self.strategy = strategy

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        if isinstance(shard_id, int):
            return shard_id

Review thread on the new compressed_tensors.utils import:

Do we need to do the same thing for static per tensor?
Do what?
Do we support channel-wise for static per tensor already?
These changes only impact dynamic per token, it seems, so I was just wondering.
Yup - just dynamic per token so far. Easy addition to the static per tensor scheme - I can add it as part of this PR. Will be a few lines of code to update.
Let's do it as a follow-up.
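To ground the per-tensor vs channel-wise discussion above, here is a minimal sketch (not the PR's code; it assumes symmetric int8 weight quantization) of how the two weight strategies differ only in how many scales are kept:

# Illustration only: per-tensor vs per-channel (channel-wise) int8 weight scales.
import torch

w = torch.randn(256, 128)  # [output_dim, input_dim]

# Per-tensor: a single scale for the whole weight matrix.
scale_tensor = w.abs().max() / 127.0                       # shape: [] (scalar)

# Channel-wise: one scale per output channel (row), i.e. output_dim scales.
scale_channel = w.abs().amax(dim=1, keepdim=True) / 127.0  # shape: [256, 1]

w_q = torch.round(w / scale_channel).clamp(-128, 127).to(torch.int8)
print(scale_tensor.shape, scale_channel.shape, w_q.dtype)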
@@ -31,6 +36,9 @@ def scales_shard_splitter(
        size = logical_widths[shard_id]
        # update loaded weight with copies for broadcast.
        loaded_weight = loaded_weight.repeat(size)
        # parameter defined for scale is 2D; expand
        if len(loaded_weight.shape) == 1:
            loaded_weight = torch.unsqueeze(loaded_weight, -1)
        return param[offset:offset + size], loaded_weight

    def create_weights(self, layer: torch.nn.Module,
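A toy, self-contained rendition of what the shard splitter above does for a fused layer built from several logical shards (the widths and scale values here are hypothetical): each shard's single scale is repeated across its rows and expanded to 2D so it can be copied into the 2D weight_scale parameter:

# Toy illustration of the shard_splitter logic for per-tensor shard scales (not vLLM code).
import torch

logical_widths = [8, 8, 16]                   # e.g. fused projection widths
param = torch.zeros(sum(logical_widths), 1)   # 2D weight_scale parameter

def split(param, loaded_weight, shard_id):
    offset = sum(logical_widths[:shard_id])
    size = logical_widths[shard_id]
    # broadcast the shard's scalar scale to one value per row
    loaded_weight = loaded_weight.repeat(size)
    # parameter defined for scale is 2D; expand the 1D copy to match
    if len(loaded_weight.shape) == 1:
        loaded_weight = torch.unsqueeze(loaded_weight, -1)
    return param[offset:offset + size], loaded_weight

dst, src = split(param, torch.tensor(0.02), shard_id=2)
dst.copy_(src)
print(param[16:].flatten())   # the last shard's rows now hold 0.02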
@@ -45,13 +53,17 @@ def create_weights(self, layer: torch.nn.Module,
        # CompressedTensorsW8A8StaticTensor::create_weights for further
        # information.
        is_tensor_partitioned = len(output_partition_sizes) != 1
        weight_scale_dim = sum(
            output_partition_sizes) if is_tensor_partitioned else 1
        # when doing channel-wise quantization, number of scales
        # is equal to output_dim
        weight_scale_dim = sum(output_partition_sizes) if (
            is_tensor_partitioned
            or self.strategy == QuantizationStrategy.CHANNEL) else 1

        weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
                                      requires_grad=False)

        weight_scale = Parameter(torch.empty(weight_scale_dim,
                                             1,
                                             dtype=torch.float32),
                                 requires_grad=False)
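A small worked example of the weight_scale_dim logic above, with hypothetical partition sizes and a stand-in enum (the real QuantizationStrategy lives in compressed_tensors.utils): channel-wise quantization allocates one scale row per output channel even for an unpartitioned layer:

# Worked example of the weight_scale_dim computation (illustrative values only).
from enum import Enum

class QuantizationStrategy(str, Enum):   # stand-in for the compressed-tensors enum
    TENSOR = "tensor"
    CHANNEL = "channel"

def weight_scale_dim(output_partition_sizes, strategy):
    is_tensor_partitioned = len(output_partition_sizes) != 1
    return sum(output_partition_sizes) if (
        is_tensor_partitioned or strategy == QuantizationStrategy.CHANNEL) else 1

print(weight_scale_dim([4096], QuantizationStrategy.TENSOR))              # 1 scale
print(weight_scale_dim([4096], QuantizationStrategy.CHANNEL))             # 4096 scales
print(weight_scale_dim([1024, 1024, 1024], QuantizationStrategy.TENSOR))  # 3072 scales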
@@ -67,11 +79,19 @@ def create_weights(self, layer: torch.nn.Module,

        layer.register_parameter("weight_scale", weight_scale)
        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
        set_weight_attrs(
            weight_scale, {
                "shard_splitter": self.scales_shard_splitter,
                "logical_widths": output_partition_sizes
            })

        # Don't need a shard_splitter for channel-wise quantization
        # Use the default loading method
        if self.strategy == QuantizationStrategy.CHANNEL:
            set_weight_attrs(weight_scale, {
                "output_dim": 0,
            })
        else:
            set_weight_attrs(
                weight_scale, {
                    "logical_widths": output_partition_sizes,
                    "shard_splitter": self.scales_shard_splitter,
                })

        layer.register_parameter("weight_zero_point", weight_zero_point)
        set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
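The channel-wise branch above relies on the default loading path once the scale is tagged with output_dim 0. A simplified, hypothetical loader (not vLLM's actual weight_loader) illustrates why that is enough: channel-wise scales shard along dim 0 exactly like the rows of the int8 weight, so no custom shard_splitter is needed:

# Simplified, hypothetical loader showing why "output_dim": 0 is sufficient
# for channel-wise scales (this is NOT vLLM's actual weight_loader).
import torch

def load_sharded(param: torch.Tensor, loaded: torch.Tensor,
                 output_dim: int, tp_rank: int, tp_size: int):
    # Take this rank's slice of the checkpoint tensor along output_dim,
    # exactly as the row slices of the int8 weight itself are taken.
    shard_size = loaded.shape[output_dim] // tp_size
    shard = loaded.narrow(output_dim, tp_rank * shard_size, shard_size)
    param.copy_(shard)

full_scales = torch.rand(4096, 1)     # one scale per output channel
local_scales = torch.empty(2048, 1)   # this rank owns half the rows
load_sharded(local_scales, full_scales, output_dim=0, tp_rank=1, tp_size=2)
print(torch.equal(local_scales, full_scales[2048:]))  # True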
Review discussion on the weight_scale parameter shape:

Can you just explain why this was able to be removed?
The models we're getting now have shapes defined for the scales. We were only casting the shapes initially because, previously, x.shape returned empty.
Makes sense. This is much cleaner.
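To illustrate the point about the checkpoints now carrying properly shaped (2D) scales, a toy dequantization (the real inference path goes through custom ops): a scale of shape [output_dim, 1] broadcasts cleanly against an int8 weight of shape [output_dim, input_dim], so no extra reshaping of the loaded scale is needed:

# Toy dequantization showing the [output_dim, 1] scale broadcasting (illustrative only).
import torch

w_q = torch.randint(-128, 128, (256, 128), dtype=torch.int8)
weight_scale = torch.rand(256, 1, dtype=torch.float32)   # 2D, as stored/loaded

w_fp = w_q.to(torch.float32) * weight_scale   # broadcasts over the input dim
print(w_fp.shape)                             # torch.Size([256, 128])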