diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index b78081155e2ba..aaa366335d196 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -13,8 +13,12 @@
     CompressedTensorsW8A8StaticTensor)
 
 
-def test_compressed_tensors_w8a8_static_setup(vllm_runner):
-    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
+@pytest.mark.parametrize("model_args", [
+    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor"),
+    ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel"),
+])
+def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
+    model_path, strategy = model_args
     with vllm_runner(model_path, enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
@@ -33,12 +37,14 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
 
         assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)
+        assert qkv_proj.scheme.strategy == strategy
 
         assert qkv_proj.weight.dtype is torch.int8
         assert o_proj.weight.dtype is torch.int8
         assert gate_up_proj.weight.dtype is torch.int8
 
-        assert qkv_proj.weight_scale.shard_splitter is not None
-        assert qkv_proj.weight_scale.logical_widths is not None
+        if qkv_proj.scheme.strategy == "tensor":
+            assert qkv_proj.weight_scale.shard_splitter is not None
+            assert qkv_proj.weight_scale.logical_widths is not None
         assert qkv_proj.input_scale.dtype is torch.float32
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 347a052a663da..44dd024afe74d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -85,8 +85,11 @@ def get_config_filenames(cls) -> List[str]:
     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
-        is_tensor = (weight_quant.strategy == input_quant.strategy ==
-                     QuantizationStrategy.TENSOR.value)
+        weight_strategy = (
+            weight_quant.strategy == QuantizationStrategy.TENSOR.value
+            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
+        is_tensor = (weight_strategy and input_quant.strategy
+                     == QuantizationStrategy.TENSOR.value)
         is_symmetric = weight_quant.symmetric and input_quant.symmetric
         is_static = not weight_quant.dynamic and not input_quant.dynamic
 
@@ -131,7 +134,8 @@ def _get_schema(self, weight_quant: BaseModel,
 
         if self.quant_format == CompressionFormat.int_quantized.value:
             if self._is_static_tensor_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8StaticTensor()
+                return CompressedTensorsW8A8StaticTensor(
+                    strategy=weight_quant.strategy)
 
             if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                 return CompressedTensorsW8A8DynamicToken(
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
new file mode 100644
index 0000000000000..efed79ec7a11c
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
@@ -0,0 +1,84 @@
+from typing import Callable, List, Tuple, Union
+
+import torch
+from torch.nn import Parameter
+
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationStrategy)
+from vllm.model_executor.utils import set_weight_attrs
+
+
+class CompressedTensorsW8A8(CompressedTensorsScheme):
+
+    def __init__(self, strategy: str):
+        self.strategy = strategy
+
+    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
+        if isinstance(shard_id, int):
+            return shard_id
+
+        assert isinstance(shard_id, str)
+        qkv_idxs = {"q": 0, "k": 1, "v": 2}
+        assert shard_id in qkv_idxs
+        return qkv_idxs[shard_id]
+
+    def scales_shard_splitter(
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Union[str, int],
+            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        shard_id = self._shard_id_as_int(shard_id)
+        offset = sum(logical_widths[:shard_id])
+        size = logical_widths[shard_id]
+        # update loaded weight with copies for broadcast.
+        loaded_weight = loaded_weight.repeat(size)
+        return param[offset:offset + size], loaded_weight
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        is_tensor_partitioned = len(output_partition_sizes) != 1
+        weight_scale_dim = sum(output_partition_sizes) if (
+            is_tensor_partitioned
+            or self.strategy == QuantizationStrategy.CHANNEL) else 1
+
+        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            shape = (weight_scale_dim, 1)
+
+        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
+                                 requires_grad=False)
+
+        layer.register_parameter("weight_scale", weight_scale)
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+
+        weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                       input_size_per_partition,
+                                       dtype=torch.int8),
+                           requires_grad=False)
+
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(
+            weight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "weight_loader": weight_loader,
+                "logical_widths": output_partition_sizes
+            })
+
+        # Don't need a shard_splitter for channel-wise quantization
+        # Use the default loading method
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            set_weight_attrs(weight_scale, {
+                "output_dim": 0,
+            })
+        else:
+            set_weight_attrs(
+                weight_scale, {
+                    "logical_widths": output_partition_sizes,
+                    "shard_splitter": self.scales_shard_splitter,
+                })
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
index 37610c9c2898b..5fc05b8e682d6 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
@@ -1,42 +1,15 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List
 
 import torch
-from torch.nn import Parameter
 
 from vllm import _custom_ops as custom_ops
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import (  # noqa: E501
+    CompressedTensorsW8A8)
 
 __all__ = ["CompressedTensorsW8A8DynamicToken"]
 
 
-class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
-
-    def __init__(self, strategy: str):
-        self.strategy = strategy
-
-    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
-        if isinstance(shard_id, int):
-            return shard_id
-
-        assert isinstance(shard_id, str)
-        qkv_idxs = {"q": 0, "k": 1, "v": 2}
-        assert shard_id in qkv_idxs
-        return qkv_idxs[shard_id]
-
-    def scales_shard_splitter(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int],
-            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        shard_id = self._shard_id_as_int(shard_id)
-        offset = sum(logical_widths[:shard_id])
-        size = logical_widths[shard_id]
-        # update loaded weight with copies for broadcast.
-        loaded_weight = loaded_weight.repeat(size)
-        return param[offset:offset + size], loaded_weight
+class CompressedTensorsW8A8DynamicToken(CompressedTensorsW8A8):
 
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
@@ -44,54 +17,12 @@ def create_weights(self, layer: torch.nn.Module,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
 
-        # When the scales have a single value, it is required that they be
-        # on the CPU for performance and CUDA Graphs compatibility. Please
-        # refer to the comment in
-        # CompressedTensorsW8A8StaticTensor::create_weights for further
-        # information.
-        is_tensor_partitioned = len(output_partition_sizes) != 1
-        # when doing channel-wise quantization, number of scales
-        # is equal to output_dim
-        weight_scale_dim = sum(output_partition_sizes) if (
-            is_tensor_partitioned
-            or self.strategy == QuantizationStrategy.CHANNEL) else 1
-
-        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
-        if self.strategy == QuantizationStrategy.CHANNEL:
-            shape = (weight_scale_dim, 1)
-
-        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
-                                 requires_grad=False)
-
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=torch.int8),
-                           requires_grad=False)
-
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(
-            weight, {
-                "input_dim": 1,
-                "output_dim": 0,
-                "weight_loader": weight_loader,
-                "logical_widths": output_partition_sizes
-            })
-
-        layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
-
-        # Don't need a shard_splitter for channel-wise quantization
-        # Use the default loading method
-        if self.strategy == QuantizationStrategy.CHANNEL:
-            set_weight_attrs(weight_scale, {
-                "output_dim": 0,
-            })
-        else:
-            set_weight_attrs(
-                weight_scale, {
-                    "logical_widths": output_partition_sizes,
-                    "shard_splitter": self.scales_shard_splitter,
-                })
+        super().create_weights(
+            layer=layer,
+            output_partition_sizes=output_partition_sizes,
+            input_size_per_partition=input_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=weight_loader)
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
index 414e17a061fb4..79f5358a365ed 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
@@ -1,37 +1,17 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List
 
 import torch
 from torch.nn import Parameter
 
 from vllm import _custom_ops as custom_ops
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import (  # noqa: E501
+    CompressedTensorsW8A8)
 from vllm.model_executor.utils import set_weight_attrs
 
 __all__ = ["CompressedTensorsW8A8StaticTensor"]
 
 
-class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
-
-    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
-        if isinstance(shard_id, int):
-            return shard_id
-
-        assert isinstance(shard_id, str)
-        qkv_idxs = {"q": 0, "k": 1, "v": 2}
-        assert shard_id in qkv_idxs
-        return qkv_idxs[shard_id]
-
-    def scales_shard_splitter(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int],
-            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        shard_id = self._shard_id_as_int(shard_id)
-        offset = sum(logical_widths[:shard_id])
-        size = logical_widths[shard_id]
-        # update loaded weight with copies for broadcast.
-        loaded_weight = loaded_weight.repeat(size)
-        return param[offset:offset + size], loaded_weight
+class CompressedTensorsW8A8StaticTensor(CompressedTensorsW8A8):
 
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
@@ -39,41 +19,21 @@ def create_weights(self, layer: torch.nn.Module,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
 
-        is_tensor_partitioned = len(output_partition_sizes) != 1
-        weight_scale_dim = sum(
-            output_partition_sizes) if is_tensor_partitioned else 1
+        super().create_weights(
+            layer=layer,
+            output_partition_sizes=output_partition_sizes,
+            input_size_per_partition=input_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=weight_loader)
 
         input_scale = Parameter(torch.empty(1, dtype=torch.float32),
                                 requires_grad=False)
 
-        weight_scale = Parameter(torch.empty(weight_scale_dim,
-                                             dtype=torch.float32),
-                                 requires_grad=False)
-
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=torch.int8),
-                           requires_grad=False)
-
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            "weight_loader": weight_loader,
-            "input_dim": 1,
-            "output_dim": 0,
-        })
         layer.register_parameter("input_scale", input_scale)
         set_weight_attrs(input_scale, {
            "weight_loader": weight_loader,
            "ignore_warning": True,
        })
-        layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(
-            weight_scale, {
-                "weight_loader": weight_loader,
-                "shard_splitter": self.scales_shard_splitter,
-                "logical_widths": output_partition_sizes,
-                "ignore_warning": True,
-            })
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
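
Reviewer note (not part of the patch): below is a minimal, standalone sketch of what the shared `scales_shard_splitter` in the new `CompressedTensorsW8A8` base class does for a fused QKV layer under the per-tensor strategy, and why the per-channel strategy can instead rely on the default row-wise loading (`output_dim=0`). The shard widths and scale value are made-up toy numbers, and the helper below is a simplified stand-in rather than the vLLM implementation.

```python
import torch

# Hypothetical logical widths of the fused QKV projection (q, k, v shards).
logical_widths = [8, 4, 4]

# Per-tensor strategy: one scale per logical shard, stored in a flat
# (sum(logical_widths),) parameter so each shard's single scale can be
# broadcast over that shard's output rows at load time.
weight_scale = torch.zeros(sum(logical_widths), dtype=torch.float32)


def scales_shard_splitter(param, loaded_scale, shard_id, logical_widths):
    # Map the "q"/"k"/"v" shard id to an index, locate the shard's slice,
    # and repeat the checkpoint's single scale across the shard's rows.
    idx = {"q": 0, "k": 1, "v": 2}[shard_id]
    offset = sum(logical_widths[:idx])
    size = logical_widths[idx]
    return param[offset:offset + size], loaded_scale.repeat(size)


# Simulate loading the checkpoint's per-tensor scale for the "k" shard.
param_slice, expanded = scales_shard_splitter(
    weight_scale, torch.tensor([0.02]), "k", logical_widths)
param_slice.copy_(expanded)
print(weight_scale[8:12])  # tensor([0.0200, 0.0200, 0.0200, 0.0200])

# Per-channel strategy: the checkpoint already stores one scale per output
# row, shaped (sum(logical_widths), 1), so slicing along output_dim=0 via
# the default weight loader is sufficient and no shard splitter is needed.
channel_scale = torch.zeros(sum(logical_widths), 1, dtype=torch.float32)
```

This is also why the updated test only asserts the `shard_splitter` and `logical_widths` attributes when `strategy == "tensor"`.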