
[Misc] Add channel-wise quantization support for w8a8 dynamic per token activation quantization #5542

Merged
13 changes: 9 additions & 4 deletions tests/quantization/test_compressed_tensors.py
@@ -14,7 +14,7 @@


def test_compressed_tensors_w8a8_static_setup(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
Collaborator: Why do these models need to change? Do the existing ones stop working?

Contributor Author: Changes on compressed-tensors' end in terms of explicit shapes for the scales (i.e. x.shape has len > 1 now, whereas before it did not).

Sara actually updated the old models last night, so these can be reverted back.
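
For illustration, a minimal sketch of the shape difference being described, using hypothetical scale values (not code from this PR):

import torch

# Older compressed-tensors checkpoints serialized per-tensor scales as 0-dim tensors,
# so x.shape was torch.Size([]) and vLLM reshaped them to (1,) before copying.
old_scale = torch.tensor(0.025)
assert len(old_scale.shape) == 0

# Newer checkpoints store scales with an explicit shape: (1,) for per-tensor scales,
# or (num_output_channels, 1) for channel-wise scales, so no reshape is needed.
new_tensor_scale = torch.tensor([0.025])
new_channel_scale = torch.rand(2048, 1)  # hypothetical output size
assert len(new_tensor_scale.shape) == 1
assert len(new_channel_scale.shape) == 2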

Collaborator: sounds good

Collaborator: I would prefer to switch back to make the PR cleaner.

Collaborator: nvm - we are green, so let's go.

model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
with vllm_runner(model_path, enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
@@ -43,15 +43,19 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):


def test_compressed_tensors_no_enforce_eager(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
with vllm_runner(model_path) as llm:
sampling_params = SamplingParams()
output = llm.generate("Hello world!", sampling_params=sampling_params)
assert output


def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
@pytest.mark.parametrize("model_args", [
("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
])
def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
model_path, strategy = model_args
with vllm_runner(model_path, dtype=torch.float16) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
@@ -60,6 +64,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):

assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8


13 changes: 0 additions & 13 deletions vllm/model_executor/layers/linear.py
@@ -468,13 +468,6 @@ def weight_loader(self,
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")

if fp8_scales_shard_indexer is None:
Collaborator: Can you just explain why this was about to be removed?

Contributor Author: The models we're getting now have shapes defined for the scales. We were only reshaping them here because x.shape used to come back empty.

Collaborator: Makes sense. This is much cleaner.

if len(param_data.shape) == 0:
param_data = param_data.reshape(1)

if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

@@ -686,12 +679,6 @@ def weight_loader(self,
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")

if len(param_data.shape) == 0:
param_data = param_data.reshape(1)

if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

@@ -95,14 +95,15 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
is_token_tensor = (weight_quant.strategy
== QuantizationStrategy.TENSOR.value) and (
input_quant.strategy
== QuantizationStrategy.TOKEN.value)
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_dynamic = not weight_quant.dynamic and input_quant.dynamic

return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
return is_8_bits and is_token and is_symmetric and is_dynamic

def _is_w4a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
@@ -133,7 +134,8 @@ def _get_schema(self, weight_quant: BaseModel,
return CompressedTensorsW8A8StaticTensor()

if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8DynamicToken()
return CompressedTensorsW8A8DynamicToken(
strategy=weight_quant.strategy)

raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
@@ -6,13 +6,18 @@
from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
Collaborator: Do we need to do the same thing for static per-tensor?

Contributor Author: Do what?

Collaborator: Do we support channel-wise for static per-tensor already?

Collaborator: These changes only seem to impact dynamic per token, so I was just wondering.

Contributor Author: Yup - just dynamic per token so far. It's an easy addition to the static per-tensor scheme - I can add it as part of this PR. It will only be a few lines of code to update.

Collaborator: Let's do it as a follow-up.

QuantizationStrategy)
from vllm.model_executor.utils import set_weight_attrs

__all__ = ["CompressedTensorsW8A8DynamicToken"]


class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):

def __init__(self, strategy: str):
self.strategy = strategy

def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id
@@ -45,11 +50,17 @@ def create_weights(self, layer: torch.nn.Module,
# CompressedTensorsW8A8StaticTensor::create_weights for further
# information.
is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum(
output_partition_sizes) if is_tensor_partitioned else 1
# when doing channel-wise quantization, number of scales
# is equal to output_dim
weight_scale_dim = sum(output_partition_sizes) if (
is_tensor_partitioned
or self.strategy == QuantizationStrategy.CHANNEL) else 1

shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
if self.strategy == QuantizationStrategy.CHANNEL:
shape = (weight_scale_dim, 1)

weight_scale = Parameter(torch.empty(weight_scale_dim,
dtype=torch.float32),
weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
requires_grad=False)

weight = Parameter(torch.empty(sum(output_partition_sizes),
@@ -67,12 +78,20 @@
})

layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(
weight_scale, {
"weight_loader": weight_loader,
"shard_splitter": self.scales_shard_splitter,
"logical_widths": output_partition_sizes
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})

# Don't need a shard_splitter for channel-wise quantization
# Use the default loading method
if self.strategy == QuantizationStrategy.CHANNEL:
set_weight_attrs(weight_scale, {
"output_dim": 0,
})
else:
set_weight_attrs(
weight_scale, {
"logical_widths": output_partition_sizes,
"shard_splitter": self.scales_shard_splitter,
})

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight
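
To recap the weight_scale handling above, a minimal sketch of how the scale parameter is shaped under each strategy; the partition sizes are hypothetical and the real logic lives in create_weights:

import torch
from torch.nn import Parameter

output_partition_sizes = [4096, 4096, 1024]  # hypothetical q/k/v partition widths
strategy = "channel"

is_tensor_partitioned = len(output_partition_sizes) != 1
# Channel-wise quantization needs one scale per output channel, so the scale count
# equals the summed output dimension even when the layer is not partitioned.
weight_scale_dim = sum(output_partition_sizes) if (
    is_tensor_partitioned or strategy == "channel") else 1

# Channel-wise scales are stored as an (N, 1) column so they broadcast against the
# (N, K) weight; with output_dim=0 the default row-wise loader shards them, so no
# custom shard_splitter is needed. Per-tensor scales stay a flat (num_shards,) vector.
shape = (weight_scale_dim, 1) if strategy == "channel" else (weight_scale_dim, )
weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32), requires_grad=False)
print(weight_scale.shape)  # torch.Size([9216, 1]) for the channel-wise case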