Commit

[Misc] Add channel-wise quantization support for w8a8 dynamic per token activation quantization (vllm-project#5542)
dsikka authored Jun 18, 2024
1 parent d142b7f commit b4e4f7f
Showing 4 changed files with 45 additions and 32 deletions.
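For context (not part of the commit): with W8A8 dynamic per-token activation quantization, each activation row gets its own scale computed at runtime, while the weight scale is either a single value (per-tensor strategy) or one value per output channel (channel-wise strategy). A minimal self-contained sketch of that arithmetic, with all shapes and values illustrative:

import torch

def quantize_weight(w: torch.Tensor, per_channel: bool):
    # symmetric int8 quantization of the weight
    if per_channel:
        # one scale per output channel -> shape (out_features, 1)
        scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    else:
        # a single scale for the whole tensor -> shape (1,)
        scale = w.abs().amax().reshape(1) / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale

def quantize_activation_per_token(x: torch.Tensor):
    # dynamic per-token scales: one scale per row of the activation
    scale = x.abs().amax(dim=-1, keepdim=True) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

w = torch.randn(8, 16)   # (out_features, in_features)
x = torch.randn(4, 16)   # (num_tokens, in_features)
qw, w_scale = quantize_weight(w, per_channel=True)
qx, x_scale = quantize_activation_per_token(x)
# int8 matmul emulated in float, then dequantized with both sets of scales
y = (qx.float() @ qw.float().t()) * x_scale * w_scale.t()
print(w_scale.shape, x_scale.shape, y.shape)   # (8, 1), (4, 1), (4, 8)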
13 changes: 9 additions & 4 deletions tests/quantization/test_compressed_tensors.py
@@ -14,7 +14,7 @@


def test_compressed_tensors_w8a8_static_setup(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
with vllm_runner(model_path, enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
@@ -43,15 +43,19 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):


def test_compressed_tensors_no_enforce_eager(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
with vllm_runner(model_path) as llm:
sampling_params = SamplingParams()
output = llm.generate("Hello world!", sampling_params=sampling_params)
assert output


def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
@pytest.mark.parametrize("model_args", [
("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
])
def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
model_path, strategy = model_args
with vllm_runner(model_path, dtype=torch.float16) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
@@ -60,6 +64,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):

assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8


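For local verification (not part of the commit; the invocation is illustrative and assumes a working vLLM checkout with pytest installed), the parametrized test above can be selected by name:

import pytest

# "-k" filters by test-name substring; this picks up both the "tensor"
# and the "channel" parametrizations of the dynamic per-token test.
pytest.main(["tests/quantization/test_compressed_tensors.py",
             "-k", "dynanmic_per_token"])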
13 changes: 0 additions & 13 deletions vllm/model_executor/layers/linear.py
@@ -468,13 +468,6 @@ def weight_loader(self,
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")

if fp8_scales_shard_indexer is None:
if len(param_data.shape) == 0:
param_data = param_data.reshape(1)

if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

@@ -686,12 +679,6 @@ def weight_loader(self,
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")

if len(param_data.shape) == 0:
param_data = param_data.reshape(1)

if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

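The branches deleted above reshaped 0-dim (scalar) parameters, such as a per-tensor scale, to 1-dim before the shape check. A standalone sketch of that behavior (the tensors and values are illustrative):

import torch

param_data = torch.tensor(1.0)     # a 0-dim (scalar) parameter, e.g. a per-tensor scale
loaded_weight = torch.tensor(0.5)  # the checkpoint value being loaded

# what the removed code did before asserting matching shapes
if len(param_data.shape) == 0:
    param_data = param_data.reshape(1)
if len(loaded_weight.shape) == 0:
    loaded_weight = loaded_weight.reshape(1)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
print(param_data)   # tensor([0.5000])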
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -95,14 +95,15 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
is_token_tensor = (weight_quant.strategy
== QuantizationStrategy.TENSOR.value) and (
input_quant.strategy
== QuantizationStrategy.TOKEN.value)
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_dynamic = not weight_quant.dynamic and input_quant.dynamic

return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
return is_8_bits and is_token and is_symmetric and is_dynamic

def _is_w4a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
@@ -133,7 +134,8 @@ def _get_schema(self, weight_quant: BaseModel,
return CompressedTensorsW8A8StaticTensor()

if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8DynamicToken()
return CompressedTensorsW8A8DynamicToken(
strategy=weight_quant.strategy)

raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
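A self-contained sketch of the relaxed check, using a simplified stand-in for the pydantic BaseModel; the field names mirror the ones used above, and the strategy strings are assumed to match the QuantizationStrategy enum values:

from dataclasses import dataclass

TENSOR, CHANNEL, TOKEN = "tensor", "channel", "token"

@dataclass
class QuantArgs:   # simplified stand-in for the BaseModel used in vLLM
    num_bits: int
    strategy: str
    symmetric: bool
    dynamic: bool

def is_dynamic_token_w8a8(weight_quant: QuantArgs, input_quant: QuantArgs) -> bool:
    is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
    # weights may now be quantized per-tensor OR per-channel
    weight_strategy = weight_quant.strategy in (TENSOR, CHANNEL)
    is_token = weight_strategy and input_quant.strategy == TOKEN
    is_symmetric = weight_quant.symmetric and input_quant.symmetric
    is_dynamic = not weight_quant.dynamic and input_quant.dynamic
    return is_8_bits and is_token and is_symmetric and is_dynamic

# channel-wise static weights + dynamic per-token activations now match the scheme
weights = QuantArgs(num_bits=8, strategy=CHANNEL, symmetric=True, dynamic=False)
acts = QuantArgs(num_bits=8, strategy=TOKEN, symmetric=True, dynamic=True)
print(is_dynamic_token_w8a8(weights, acts))   # True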
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
@@ -6,13 +6,18 @@
from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.utils import set_weight_attrs

__all__ = ["CompressedTensorsW8A8DynamicToken"]


class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):

def __init__(self, strategy: str):
self.strategy = strategy

def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id
@@ -45,11 +50,17 @@ def create_weights(self, layer: torch.nn.Module,
# CompressedTensorsW8A8StaticTensor::create_weights for further
# information.
is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum(
output_partition_sizes) if is_tensor_partitioned else 1
# when doing channel-wise quantization, number of scales
# is equal to output_dim
weight_scale_dim = sum(output_partition_sizes) if (
is_tensor_partitioned
or self.strategy == QuantizationStrategy.CHANNEL) else 1

shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
if self.strategy == QuantizationStrategy.CHANNEL:
shape = (weight_scale_dim, 1)

weight_scale = Parameter(torch.empty(weight_scale_dim,
dtype=torch.float32),
weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
requires_grad=False)

weight = Parameter(torch.empty(sum(output_partition_sizes),
@@ -67,12 +78,20 @@ def create_weights(self, layer: torch.nn.Module,
})

layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(
weight_scale, {
"weight_loader": weight_loader,
"shard_splitter": self.scales_shard_splitter,
"logical_widths": output_partition_sizes
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})

# Don't need a shard_splitter for channel-wise quantization
# Use the default loading method
if self.strategy == QuantizationStrategy.CHANNEL:
set_weight_attrs(weight_scale, {
"output_dim": 0,
})
else:
set_weight_attrs(
weight_scale, {
"logical_widths": output_partition_sizes,
"shard_splitter": self.scales_shard_splitter,
})

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight
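A standalone sketch mirroring the weight_scale shape logic in create_weights above (the partition sizes are illustrative, and CHANNEL stands in for QuantizationStrategy.CHANNEL):

from typing import List, Tuple, Union

CHANNEL = "channel"   # stand-in for QuantizationStrategy.CHANNEL

def weight_scale_shape(output_partition_sizes: List[int],
                       strategy: str) -> Union[Tuple[int], Tuple[int, int]]:
    is_tensor_partitioned = len(output_partition_sizes) != 1
    # channel-wise quantization needs one scale per output channel
    weight_scale_dim = sum(output_partition_sizes) if (
        is_tensor_partitioned or strategy == CHANNEL) else 1
    shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
    if strategy == CHANNEL:
        # kept 2-D so the scale broadcasts across the weight's output channels
        shape = (weight_scale_dim, 1)
    return shape

# a fused projection with three partitions vs. a single-partition layer
print(weight_scale_shape([4096, 1024, 1024], "tensor"))    # (6144,)
print(weight_scale_shape([4096, 1024, 1024], CHANNEL))     # (6144, 1)
print(weight_scale_shape([4096], "tensor"))                # (1,)
print(weight_scale_shape([4096], CHANNEL))                 # (4096, 1)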
