From 00087152795ebc9362263d753ac26b805ed50633 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Tue, 18 Jun 2024 12:45:05 -0400
Subject: [PATCH] [Misc] Add channel-wise quantization support for w8a8
 dynamic per token activation quantization (#5542)

---
 tests/quantization/test_compressed_tensors.py | 13 +++++--
 vllm/model_executor/layers/linear.py          | 13 -------
 .../compressed_tensors/compressed_tensors.py  | 14 ++++---
 .../compressed_tensors_w8a8_dynamictoken.py   | 37 ++++++++++++++-----
 4 files changed, 45 insertions(+), 32 deletions(-)

diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index a3767dd454fa8..f35fc8c223056 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -19,7 +19,7 @@
 
 
 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
-    model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
+    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
     with vllm_runner(model_path, enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
@@ -48,15 +48,19 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
 
 
 def test_compressed_tensors_no_enforce_eager(vllm_runner):
-    model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
+    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
     with vllm_runner(model_path) as llm:
         sampling_params = SamplingParams()
         output = llm.generate("Hello world!", sampling_params=sampling_params)
         assert output
 
 
-def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
-    model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
+@pytest.mark.parametrize("model_args", [
+    ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
+    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
+])
+def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
+    model_path, strategy = model_args
     with vllm_runner(model_path, dtype=torch.float16) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
@@ -65,6 +69,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
 
         assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
         assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
+        assert qkv_proj.scheme.strategy == strategy
         assert qkv_proj.weight.dtype is torch.int8
 
 
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 9598caea47071..a18ea5601ba94 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -476,13 +476,6 @@ def weight_loader(self,
                 "MergedColumnParallelLinear, assume the weight is "
                 "the same for all partitions.")
 
-        if fp8_scales_shard_indexer is None:
-            if len(param_data.shape) == 0:
-                param_data = param_data.reshape(1)
-
-            if len(loaded_weight.shape) == 0:
-                loaded_weight = loaded_weight.reshape(1)
-
         # UPSTREAM SYNC: needed for LazyCompressedParameter
         self.loaded_shards.add(loaded_shard_id)
         assert param_data.shape == loaded_weight.shape
@@ -707,12 +700,6 @@ def weight_loader(self,
                 "QKVParallelLinear, assume the weight is the same "
                 "for all partitions.")
 
-        if len(param_data.shape) == 0:
-            param_data = param_data.reshape(1)
-
-        if len(loaded_weight.shape) == 0:
-            loaded_weight = loaded_weight.reshape(1)
-
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 92a84b3c0dd89..347a052a663da 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -95,14 +95,15 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
     def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
-        is_token_tensor = (weight_quant.strategy
-                           == QuantizationStrategy.TENSOR.value) and (
-                               input_quant.strategy
-                               == QuantizationStrategy.TOKEN.value)
+        weight_strategy = (
+            weight_quant.strategy == QuantizationStrategy.TENSOR.value
+            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
+        is_token = (weight_strategy and input_quant.strategy
+                    == QuantizationStrategy.TOKEN.value)
         is_symmetric = weight_quant.symmetric and input_quant.symmetric
         is_dynamic = not weight_quant.dynamic and input_quant.dynamic
 
-        return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
+        return is_8_bits and is_token and is_symmetric and is_dynamic
 
     def _is_w4a16(self, weight_quant: BaseModel,
                   input_quant: BaseModel) -> bool:
@@ -133,7 +134,8 @@ def _get_schema(self, weight_quant: BaseModel,
             return CompressedTensorsW8A8StaticTensor()
 
         if self._is_dynamic_token_w8a8(weight_quant, input_quant):
-            return CompressedTensorsW8A8DynamicToken()
+            return CompressedTensorsW8A8DynamicToken(
+                strategy=weight_quant.strategy)
 
         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
index d514d7b28cfd9..37610c9c2898b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
@@ -6,6 +6,8 @@
 from vllm import _custom_ops as custom_ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationStrategy)
 from vllm.model_executor.utils import set_weight_attrs
 
 __all__ = ["CompressedTensorsW8A8DynamicToken"]
@@ -13,6 +15,9 @@
 
 class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
 
+    def __init__(self, strategy: str):
+        self.strategy = strategy
+
     def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
         if isinstance(shard_id, int):
             return shard_id
@@ -45,11 +50,17 @@ def create_weights(self, layer: torch.nn.Module,
         # CompressedTensorsW8A8StaticTensor::create_weights for further
         # information.
         is_tensor_partitioned = len(output_partition_sizes) != 1
-        weight_scale_dim = sum(
-            output_partition_sizes) if is_tensor_partitioned else 1
+        # when doing channel-wise quantization, number of scales
+        # is equal to output_dim
+        weight_scale_dim = sum(output_partition_sizes) if (
+            is_tensor_partitioned
+            or self.strategy == QuantizationStrategy.CHANNEL) else 1
+
+        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            shape = (weight_scale_dim, 1)
 
-        weight_scale = Parameter(torch.empty(weight_scale_dim,
-                                             dtype=torch.float32),
+        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
                                  requires_grad=False)
 
         weight = Parameter(torch.empty(sum(output_partition_sizes),
@@ -67,12 +78,20 @@ def create_weights(self, layer: torch.nn.Module,
             })
 
         layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(
-            weight_scale, {
-                "weight_loader": weight_loader,
-                "shard_splitter": self.scales_shard_splitter,
-                "logical_widths": output_partition_sizes
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+
+        # Don't need a shard_splitter for channel-wise quantization
+        # Use the default loading method
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            set_weight_attrs(weight_scale, {
+                "output_dim": 0,
             })
+        else:
+            set_weight_attrs(
+                weight_scale, {
+                    "logical_widths": output_partition_sizes,
+                    "shard_splitter": self.scales_shard_splitter,
+                })
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
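-- 
The standalone Python sketch below (illustrative only; the helper names are
hypothetical, and the real path runs through vLLM's custom_ops kernels)
shows the numerics this scheme targets: symmetric int8 weights with one
scale per output channel, shape (num_output_channels, 1) exactly as
create_weights allocates for the CHANNEL strategy, combined with dynamic
per-token int8 activations, so both scale vectors broadcast cleanly over
the int32 accumulator.

import torch

torch.manual_seed(0)


def quantize_weight_per_channel(w: torch.Tensor):
    # Symmetric int8 weight quantization with one scale per output
    # channel (row): the CHANNEL strategy. Scale shape is (out, 1).
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    q = torch.clamp((w / scale).round(), -128, 127).to(torch.int8)
    return q, scale


def quantize_act_per_token(x: torch.Tensor):
    # Dynamic symmetric int8 activation quantization with one scale per
    # token (row of x): the TOKEN strategy, computed at runtime.
    scale = x.abs().amax(dim=-1, keepdim=True) / 127.0
    q = torch.clamp((x / scale).round(), -128, 127).to(torch.int8)
    return q, scale


# Reference w8a8 matmul: accumulate int8 products in int32, then rescale
# with the per-token and per-channel scales via broadcasting.
w = torch.randn(64, 32)                        # (out_features, in_features)
x = torch.randn(8, 32)                         # (tokens, in_features)
qw, w_scale = quantize_weight_per_channel(w)   # w_scale: (64, 1)
qx, x_scale = quantize_act_per_token(x)        # x_scale: (8, 1)
acc = qx.to(torch.int32) @ qw.t().to(torch.int32)
y = acc.to(torch.float32) * x_scale * w_scale.t()
assert torch.allclose(y, x @ w.t(), atol=0.5, rtol=0.1)

Because the per-channel scale tensor shards along the same output
dimension as the weight itself, the CHANNEL branch above can rely on the
default output_dim-based weight loading instead of a custom
shard_splitter, which is what the patch's comment in create_weights
points out.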