[Misc] Add channel-wise quantization support for w8a8 dynamic per token activation quantization #5542

Merged

Changes from 5 commits
15 changes: 1 addition & 14 deletions vllm/model_executor/layers/linear.py
@@ -468,13 +468,6 @@ def weight_loader(self,
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")

if fp8_scales_shard_indexer is None:
Collaborator: Can you explain why this is being removed?

Contributor (author): The models we're getting now have shapes defined for the scales. We were reshaping initially because x.shape used to come back empty.

Collaborator: Makes sense. This is much cleaner.

if len(param_data.shape) == 0:
param_data = param_data.reshape(1)

if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
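
For context on the thread above, the deleted guard only mattered when a per-tensor scale was serialized as a 0-d tensor, whose empty shape fails the shape-equality assert. A minimal standalone sketch (illustrative values, not vLLM code):

```python
import torch

# Older checkpoints stored a per-tensor scale as a 0-d tensor, so its shape
# is torch.Size([]) and the assert in weight_loader would fail without a reshape.
param_data = torch.empty(1)
loaded_weight = torch.tensor(0.02)            # 0-d scale, shape == torch.Size([])

if len(loaded_weight.shape) == 0:             # the guard this PR removes
    loaded_weight = loaded_weight.reshape(1)  # -> shape (1,)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)               # newer checkpoints ship shaped scales,
                                              # so this reshape is no longer needed
```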

@@ -686,12 +679,6 @@ def weight_loader(self,
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")

if len(param_data.shape) == 0:
param_data = param_data.reshape(1)

if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

@@ -784,7 +771,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
shard_id=0)

if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
loaded_weight = loaded_weight.reshape(1, 1)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
@@ -88,14 +88,15 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
is_token_tensor = (weight_quant.strategy
== QuantizationStrategy.TENSOR.value) and (
input_quant.strategy
== QuantizationStrategy.TOKEN.value)
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_dynamic = not weight_quant.dynamic and input_quant.dynamic

return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
return is_8_bits and is_token and is_symmetric and is_dynamic

def _is_w4a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
@@ -118,7 +119,8 @@ def _get_schema(self, weight_quant: BaseModel,
return CompressedTensorsW8A8StaticTensor()

if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8DynamicToken()
return CompressedTensorsW8A8DynamicToken(
strategy=weight_quant.strategy)

raise NotImplementedError("Scheme not supported.")
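
To illustrate the combination the updated predicate now accepts, here is a hedged, standalone sketch; the stand-in objects and the lowercase strategy strings are assumptions standing in for the compressed-tensors quantization args, not the library API:

```python
from types import SimpleNamespace

# Channel-wise weights + dynamic per-token activations, mirroring the checks above.
weight_quant = SimpleNamespace(num_bits=8, strategy="channel",
                               symmetric=True, dynamic=False)
input_quant = SimpleNamespace(num_bits=8, strategy="token",
                              symmetric=True, dynamic=True)

is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = weight_quant.strategy in ("tensor", "channel")
is_token = weight_strategy and input_quant.strategy == "token"
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_dynamic = not weight_quant.dynamic and input_quant.dynamic

assert is_8_bits and is_token and is_symmetric and is_dynamic
```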

@@ -6,13 +6,18 @@
from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    QuantizationStrategy)

Collaborator: Do we need to do the same thing for static per tensor?

Contributor (author): Do what?

Collaborator: Do we support channel-wise for static per tensor already?

Collaborator: These changes only impact dynamic per token, it seems, so I was just wondering.

Contributor (author): Yup - just dynamic per token so far. It's an easy addition to the static per tensor scheme; I can add it as part of this PR. It will only be a few lines of code to update.

Collaborator: Let's do it as a follow-up.
from vllm.model_executor.utils import set_weight_attrs

__all__ = ["CompressedTensorsW8A8DynamicToken"]


class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):

def __init__(self, strategy: str):
self.strategy = strategy

def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id
@@ -31,6 +36,9 @@ def scales_shard_splitter(
size = logical_widths[shard_id]
# update loaded weight with copies for broadcast.
loaded_weight = loaded_weight.repeat(size)
# parameter defined for scale is 2D; expand
if len(loaded_weight.shape) == 1:
loaded_weight = torch.unsqueeze(loaded_weight, -1)
return param[offset:offset + size], loaded_weight

def create_weights(self, layer: torch.nn.Module,
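
A small standalone sketch of what the splitter now does for a per-tensor shard scale (values are made up): the scale is repeated across the shard's logical width and then unsqueezed so it lines up with the 2-D weight_scale parameter.

```python
import torch

loaded_weight = torch.tensor(0.05)    # one scale covering a shard of width 4
size = 4                              # logical width of this shard

loaded_weight = loaded_weight.repeat(size)               # shape (4,)
if len(loaded_weight.shape) == 1:
    loaded_weight = torch.unsqueeze(loaded_weight, -1)   # shape (4, 1), matches param
```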
@@ -45,13 +53,17 @@ def create_weights(self, layer: torch.nn.Module,
# CompressedTensorsW8A8StaticTensor::create_weights for further
# information.
is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum(
output_partition_sizes) if is_tensor_partitioned else 1
# when doing channel-wise quantization, number of scales
# is equal to output_dim
weight_scale_dim = sum(output_partition_sizes) if (
is_tensor_partitioned
or self.strategy == QuantizationStrategy.CHANNEL) else 1

weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
requires_grad=False)

weight_scale = Parameter(torch.empty(weight_scale_dim,
1,
dtype=torch.float32),
requires_grad=False)
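
Roughly, the scale count now works out as follows; this is an illustrative sketch with made-up sizes and a plain string in place of QuantizationStrategy, with names mirroring the diff:

```python
# Channel-wise quantization needs one scale per output channel, so the scale
# parameter is sized by the output dimension even when the weight is not
# partitioned; per-tensor quantization still collapses to a single scale.
output_partition_sizes = [4096]               # a single, unpartitioned output
is_tensor_partitioned = len(output_partition_sizes) != 1
strategy = "channel"                          # vs. "tensor"

weight_scale_dim = sum(output_partition_sizes) if (
    is_tensor_partitioned or strategy == "channel") else 1
print(weight_scale_dim)   # 4096 for "channel"; would be 1 for "tensor"
```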

@@ -67,11 +79,19 @@ def create_weights(self, layer: torch.nn.Module,

layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
set_weight_attrs(
weight_scale, {
"shard_splitter": self.scales_shard_splitter,
"logical_widths": output_partition_sizes

# Don't need a shard_splitter for channel-wise quantization
# Use the default loading method
if self.strategy == QuantizationStrategy.CHANNEL:
set_weight_attrs(weight_scale, {
"output_dim": 0,
})
else:
set_weight_attrs(
weight_scale, {
"logical_widths": output_partition_sizes,
"shard_splitter": self.scales_shard_splitter,
})

layer.register_parameter("weight_zero_point", weight_zero_point)
set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
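
In the channel-wise branch the scale rides the default weight loader: tagging the parameter with output_dim=0 lets each tensor-parallel rank copy its own slice of rows, much like the weight itself, which is why no shard_splitter is registered. A hedged, standalone sketch of that slicing (not the actual vLLM loader):

```python
import torch

# Full checkpoint scale: one value per output channel, stored as (out_features, 1).
full_scale = torch.rand(3072, 1)

tp_size, tp_rank, output_dim = 2, 0, 0
shard = full_scale.shape[output_dim] // tp_size            # 1536 rows per rank
local_scale = full_scale.narrow(output_dim, tp_rank * shard, shard)
assert local_scale.shape == (1536, 1)
```

Reusing the default output_dim path keeps the channel-wise scales consistent with how the quantized weight itself is sharded, while per-tensor scales still go through scales_shard_splitter above.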