From a5581fc6892479b7b2192ceae82c249f57770516 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Tue, 17 Dec 2024 05:38:24 +0000 Subject: [PATCH 01/12] Add: Support for Sparse24Bitmask Compressed Models Signed-off-by: Rahul Tuli --- .../compressed_tensors/compressed_tensors.py | 40 +++++-- .../schemes/compressed_tensors_24.py | 106 ++++++++++++++++-- vllm/model_executor/parameter.py | 26 ++++- 3 files changed, 153 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 1a11b2419cc88..98e34b945ac24 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -420,15 +420,19 @@ def get_scheme(self, return None # Have a valid sparsity scheme # Validate layer is supported by Cutlass 2:4 Kernel - scheme = CompressedTensors24(quantized=weight_quant is not None - or input_quant is not None, - weight_quant=weight_quant, - input_quant=input_quant) + scheme = CompressedTensors24( + quantized=weight_quant is not None or input_quant is not None, + weight_quant=weight_quant, + input_quant=input_quant, + model_compression_config=self._get_model_compression_config( + sparsity_scheme), + ) elif weight_quant is None: logger.warning_once("Acceleration for non-quantized schemes is " "not supported by Compressed Tensors. " "Falling back to UnquantizedLinearMethod") return None + else: # Find the quant_scheme scheme = self._get_scheme_from_parts( # type: ignore @@ -478,10 +482,17 @@ def supports_cutlass_24( :return: True if the layer is supported by the Cutlass 2:4 Kernel False otherwise """ - is_valid_sparsity = (sparsity_scheme is not None - and sparsity_scheme.sparsity_structure - == SparsityStructure.TWO_FOUR.value - and sparsity_scheme.format == "dense") + is_valid_sparsity_structure = (sparsity_scheme is not None + and sparsity_scheme.sparsity_structure + == SparsityStructure.TWO_FOUR.value) + valid_compressors = { + CompressionFormat.dense.value, + CompressionFormat.sparse_24_bitmask.value + } + + is_valid_sparsity = (is_valid_sparsity_structure + and sparsity_scheme.format in valid_compressors) + if not is_valid_sparsity: return False @@ -512,6 +523,19 @@ def supports_cutlass_24( return weight_quant.num_bits == input_quant.num_bits == 8 + def _get_model_compression_config( + self, sparsity_scheme: Optional[SparsityCompressionConfig] = None): + """ + Get the model compressor config from the sparsity scheme + + :param sparsity_scheme: The sparsity scheme + :return: The model compressor config + """ + if sparsity_scheme is None or sparsity_scheme.format == "dense": + return None + + return self.config + class CompressedTensorsLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 84f924b236af9..21333003d821d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -1,18 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch +from compressed_tensors import CompressionFormat, ModelCompressor from 
compressed_tensors.quantization import (QuantizationArgs, QuantizationStrategy, QuantizationType) +from compressed_tensors.utils import combine_shards from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise, sparse_cutlass_supported) from vllm.model_executor.parameter import (BasevLLMParameter, + BitMaskShapeParameter, ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) @@ -22,14 +27,24 @@ class CompressedTensors24(CompressedTensorsScheme): - def __init__(self, - quantized: bool = False, - weight_quant: Optional[QuantizationArgs] = None, - input_quant: Optional[QuantizationArgs] = None): + def __init__( + self, + quantized: bool = False, + weight_quant: Optional[QuantizationArgs] = None, + input_quant: Optional[QuantizationArgs] = None, + model_compression_config: Optional[Dict[str, Any]] = None, + ): self.quantized = quantized self.weight_quant = weight_quant self.input_quant = input_quant + self.model_compressor = ( + ModelCompressor.from_compression_config(model_compression_config) + if model_compression_config is not None else None) + self.do_sparse_decompress = ( + self.model_compressor is not None + and self.model_compressor.sparsity_config.format + == CompressionFormat.sparse_24_bitmask.value) @classmethod def get_min_capability(cls) -> int: @@ -49,6 +64,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, self.output_dtype = params_dtype layer.logical_widths = output_partition_sizes + layer.input_size = input_size + layer.input_size_per_partition = input_size_per_partition self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype) # parameter to store uncompressed weight @@ -59,6 +76,34 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, input_dim=1, output_dim=0, weight_loader=weight_loader) + if self.do_sparse_decompress: + assert all( + partition_size % 8 == 0 + for partition_size in output_partition_sizes + ), "All partitions must be divisible by 8 for 2:4 compressed models" + + shape = BitMaskShapeParameter(data=torch.empty( + 2 * len(output_partition_sizes), 1, dtype=torch.uint64), + weight_loader=weight_loader) + compressed = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=self.weights_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + bitmask = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 8, + dtype=torch.uint8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("shape", shape) + layer.register_parameter("compressed", compressed) + layer.register_parameter("bitmask", bitmask) # Check if quantized, not just 2:4 Sparse if self.quantized: @@ -114,6 +159,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: :param layer: The layer with the weights to be processed """ + if self.do_sparse_decompress: + layer.weight.data = self._decompress_bitmask_compressed_weight( + compressed=layer.compressed, + bitmask=layer.bitmask, + layer=layer, + ) + # torch.compile workaround if hasattr(layer, "input_scale"): layer.input_scale = torch.nn.Parameter(layer.input_scale.data, @@ -203,8 +255,42 @@ def _get_params_dtype(self, params_dtype: torch.dtype) 
-> torch.dtype: raise ValueError("Quantization type not supported by Cutlass") - -def check_24(tensor): - new_tensor = tensor.view(-1, 4) - zero_counts = (new_tensor == 0).sum(dim=1) - return (zero_counts >= 2).all().item() + def _decompress_bitmask_compressed_weight( + self, compressed: torch.Tensor, bitmask: torch.Tensor, + layer: torch.nn.Module) -> torch.Tensor: + + sparsity_compressor = self.model_compressor.sparsity_compressor + + def _process_split(bitmask_compressed_weight: torch.Tensor, shape, + bitmask: torch.Tensor) -> torch.Tensor: + weight_data = dict( + compressed=bitmask_compressed_weight, + shape=shape, + bitmask=bitmask, + ) + return sparsity_compressor.decompress_weight(weight_data) + + split_weights = None + split_bitmask = None + split_shape = None + + if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)): + split_weights = torch.split(compressed, layer.logical_widths) + split_bitmask = torch.split(bitmask, layer.logical_widths) + split_shape = [(out, layer.input_size_per_partition) + for out in layer.logical_widths] + + if split_weights is not None: + decompressed_shards = [ + _process_split(compressed_weight, shape, bitmask) + for compressed_weight, shape, bitmask in zip( + split_weights, split_shape, split_bitmask) + ] + decompressed = combine_shards(decompressed_shards) + else: + decompressed = sparsity_compressor.decompress_weight( + dict(compressed=compressed, + shape=(layer.logical_widths[0], + layer.input_size_per_partition), + bitmask=bitmask)) + return decompressed diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 2b1294bf7baa3..65dadf5dcad4d 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -13,7 +13,8 @@ __all__ = [ "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", "ModelWeightParameter", "ChannelQuantScaleParameter", - "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter" + "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", + "BitMaskShapeParameter" ] logger = init_logger(__name__) @@ -431,3 +432,26 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, shard_offset=shard_offset, marlin_tile_size=marlin_tile_size) return shard_size, shard_offset + + +class BitMaskShapeParameter(PerTensorScaleParameter): + """ + Parameter class for the shape of the bitmask tensor. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _load_into_shard_id(self, loaded_weight: torch.Tensor, + shard_id: Union[str, int], **kwargs): + """ + Slice the parameter data based on the shard id for + loading. + + Note: Assumes the loaded weight is a 1D tensor + with 2 elements. 
+ """ + param_data = self.data + shard_id = self._shard_id_as_int(shard_id) + start_index = shard_id * 2 + param_data[start_index:start_index + 2].copy_(loaded_weight) From d4b955bcd75d6c18a18e8f017075c496b0baf5c1 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 15 Jan 2025 20:58:21 +0000 Subject: [PATCH 02/12] Fix: mypy errors Signed-off-by: Rahul Tuli --- .../compressed_tensors/compressed_tensors.py | 10 +++++++--- .../schemes/compressed_tensors_24.py | 10 +++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 98e34b945ac24..3f22b4da2c4d4 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -482,9 +482,13 @@ def supports_cutlass_24( :return: True if the layer is supported by the Cutlass 2:4 Kernel False otherwise """ - is_valid_sparsity_structure = (sparsity_scheme is not None - and sparsity_scheme.sparsity_structure - == SparsityStructure.TWO_FOUR.value) + if sparsity_scheme is None: + return False + + is_valid_sparsity_structure: bool = ( + sparsity_scheme.sparsity_structure == + SparsityStructure.TWO_FOUR.value) + valid_compressors = { CompressionFormat.dense.value, CompressionFormat.sparse_24_bitmask.value diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 21333003d821d..a367de8808690 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from compressed_tensors import CompressionFormat, ModelCompressor @@ -270,9 +270,9 @@ def _process_split(bitmask_compressed_weight: torch.Tensor, shape, ) return sparsity_compressor.decompress_weight(weight_data) - split_weights = None - split_bitmask = None - split_shape = None + split_weights: List[torch.Tensor] = [] + split_bitmask: List[torch.Tensor] = [] + split_shape: List[Tuple[int, int]] = [] if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)): split_weights = torch.split(compressed, layer.logical_widths) @@ -280,7 +280,7 @@ def _process_split(bitmask_compressed_weight: torch.Tensor, shape, split_shape = [(out, layer.input_size_per_partition) for out in layer.logical_widths] - if split_weights is not None: + if split_weights: decompressed_shards = [ _process_split(compressed_weight, shape, bitmask) for compressed_weight, shape, bitmask in zip( From 061ae5b0a7023c1870200fbbe4a7fda9c1a48a40 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 22 Jan 2025 19:29:14 +0000 Subject: [PATCH 03/12] Removed BitmaskShape Parameter Renamed `compressed` to `compressed_weight` Address review commits from @dsikka Signed-off-by: Rahul Tuli --- .../compressed_tensors/compressed_tensors.py | 20 +++-------- .../schemes/compressed_tensors_24.py | 33 +++++++++---------- vllm/model_executor/parameter.py | 26 +-------------- 3 files changed, 22 insertions(+), 57 deletions(-) diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 3f22b4da2c4d4..7db8b18e9b08e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -420,12 +420,15 @@ def get_scheme(self, return None # Have a valid sparsity scheme # Validate layer is supported by Cutlass 2:4 Kernel + model_compression_config = (None if sparsity_scheme is None + or sparsity_scheme.format == "dense" + else self.config) + scheme = CompressedTensors24( quantized=weight_quant is not None or input_quant is not None, weight_quant=weight_quant, input_quant=input_quant, - model_compression_config=self._get_model_compression_config( - sparsity_scheme), + model_compression_config=model_compression_config, ) elif weight_quant is None: logger.warning_once("Acceleration for non-quantized schemes is " @@ -527,19 +530,6 @@ def supports_cutlass_24( return weight_quant.num_bits == input_quant.num_bits == 8 - def _get_model_compression_config( - self, sparsity_scheme: Optional[SparsityCompressionConfig] = None): - """ - Get the model compressor config from the sparsity scheme - - :param sparsity_scheme: The sparsity scheme - :return: The model compressor config - """ - if sparsity_scheme is None or sparsity_scheme.format == "dense": - return None - - return self.config - class CompressedTensorsLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index a367de8808690..88bd68dbc2f46 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -17,7 +17,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise, sparse_cutlass_supported) from vllm.model_executor.parameter import (BasevLLMParameter, - BitMaskShapeParameter, ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) @@ -77,21 +76,21 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, output_dim=0, weight_loader=weight_loader) if self.do_sparse_decompress: - assert all( - partition_size % 8 == 0 - for partition_size in output_partition_sizes - ), "All partitions must be divisible by 8 for 2:4 compressed models" - - shape = BitMaskShapeParameter(data=torch.empty( - 2 * len(output_partition_sizes), 1, dtype=torch.uint64), - weight_loader=weight_loader) - compressed = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // 2, - dtype=self.weights_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + assert all(partition_size % 8 == 0 + for partition_size in output_partition_sizes + ), "All partitions must be divisible by 8 for " + "2:4 sparse compressed models" + + shape = BasevLLMParameter(data=torch.empty(2, 1, + dtype=torch.int64), + weight_loader=weight_loader) + compressed_weight = ModelWeightParameter( + data=torch.empty(sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=self.weights_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) bitmask = ModelWeightParameter(data=torch.empty( sum(output_partition_sizes), @@ -102,7 +101,7 @@ def 
create_weights(self, layer: torch.nn.Module, input_size: int, weight_loader=weight_loader) layer.register_parameter("shape", shape) - layer.register_parameter("compressed", compressed) + layer.register_parameter("compressed", compressed_weight) layer.register_parameter("bitmask", bitmask) # Check if quantized, not just 2:4 Sparse diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 65dadf5dcad4d..2b1294bf7baa3 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -13,8 +13,7 @@ __all__ = [ "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", "ModelWeightParameter", "ChannelQuantScaleParameter", - "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", - "BitMaskShapeParameter" + "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter" ] logger = init_logger(__name__) @@ -432,26 +431,3 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, shard_offset=shard_offset, marlin_tile_size=marlin_tile_size) return shard_size, shard_offset - - -class BitMaskShapeParameter(PerTensorScaleParameter): - """ - Parameter class for the shape of the bitmask tensor. - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def _load_into_shard_id(self, loaded_weight: torch.Tensor, - shard_id: Union[str, int], **kwargs): - """ - Slice the parameter data based on the shard id for - loading. - - Note: Assumes the loaded weight is a 1D tensor - with 2 elements. - """ - param_data = self.data - shard_id = self._shard_id_as_int(shard_id) - start_index = shard_id * 2 - param_data[start_index:start_index + 2].copy_(loaded_weight) From f21edb73f6e24699e1327638fb342a6184d76d23 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 22 Jan 2025 21:15:28 +0000 Subject: [PATCH 04/12] Add: lm-eval, fp8, int8 tests Signed-off-by: Rahul Tuli --- .../SparseLlama3.1_2of4_fp8_compressed.yaml | 11 +++ tests/quantization/test_compressed_tensors.py | 73 ++++++++++++++++++- .../schemes/compressed_tensors_24.py | 12 +++ 3 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 .buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml diff --git a/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml new file mode 100644 index 0000000000000..2928d75ce4469 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml @@ -0,0 +1,11 @@ +# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2 +model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.6353 + - name: "exact_match,flexible-extract" + value: 0.637 +limit: null +num_fewshot: null diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 7e2e6f6ed5890..3d9ec9d489732 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -243,7 +243,10 @@ def test_compressed_tensors_kv_cache(vllm_runner): @pytest.mark.skipif(not sparse_cutlass_supported(), reason="Sparse FP8 is not yet supported on this GPU type.") -def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy): +def _test_2of4_quant_models(qkv_proj, + weight_strategy, + input_strategy, + 
format="dense"): assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensors24) @@ -252,7 +255,7 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy): assert qkv_proj.scheme.quantized assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501 - assert sparsity_map.get("Linear").format == "dense" + assert sparsity_map.get("Linear").format == format assert sparsity_map.get("Linear").sparsity_structure == "2:4" @@ -286,6 +289,72 @@ def check_model(model): assert output +@pytest.mark.skipif(not current_platform.has_device_capability(90), + reason="Sparse FP8 is not yet supported on this GPU type.") +@pytest.mark.parametrize("args_2of4", [ + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM", + "channel", "token"), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM", + "channel", "tensor"), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM", + "tensor", "token"), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM", + "tensor", "tensor"), +]) +def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4): + model, weight_strategy, input_strategy = args_2of4 + with vllm_runner(model) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn + _test_2of4_quant_models(qkv_proj, + weight_strategy, + input_strategy, + format="sparse-24-bitmask") + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output + + +@pytest.mark.skipif(not sparse_cutlass_supported(), + reason="cutlass is not yet supported on this GPU type.") +@pytest.mark.parametrize("args_2of4", [ + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM", + "channel", "token"), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM", + "channel", "tensor"), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM", + "tensor", "token"), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM", + "tensor", "tensor"), +]) +def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4): + model, weight_strategy, input_strategy = args_2of4 + with vllm_runner(model) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert qkv_proj.scheme.weights_dtype == torch.int8 + _test_2of4_quant_models(qkv_proj, + weight_strategy, + input_strategy, + format="sparse-24-bitmask") + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output + + @pytest.mark.skipif(not sparse_cutlass_supported(), reason="Sparse FP8 is not yet supported on this GPU type.") @pytest.mark.parametrize("args_2of4", [ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 88bd68dbc2f46..e2106da0c9815 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -257,6 +257,18 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype: def _decompress_bitmask_compressed_weight( self, compressed: torch.Tensor, bitmask: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor: + """ + Decompress a compressed 2:4 sparse weight tensor + using the bitmask and return the result. + + This function also supports sharded decompression. + + :param compressed: The 2:4 sparse weight tensor + compressed using the sparse-24-bitmask compressor. + :param bitmask: The 2:4 bitmask associated with the compressed weights. + :param layer: The layer whose weights need to be processed after loading. + :return: The decompressed 2:4 sparse weight tensor. + """ sparsity_compressor = self.model_compressor.sparsity_compressor From 132f7bf0897bc370aba26c99474ebe8eaec5531f Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 22 Jan 2025 21:35:44 +0000 Subject: [PATCH 05/12] Add: 2:4 Sparse only compressed test Signed-off-by: Rahul Tuli --- tests/quantization/test_compressed_tensors.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 3d9ec9d489732..db21e74102897 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -415,3 +415,35 @@ def check_model(model): output = llm.generate_greedy("Hello my name is", max_tokens=20) print(output) assert output + + +@pytest.mark.skipif(not sparse_cutlass_supported(), + reason="Cutlass is not yet supported on this GPU type.") +@pytest.mark.parametrize( + "args_2of4", + [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")]) +def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4): + model = args_2of4 + with vllm_runner(model) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert isinstance(qkv_proj.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensors24) + + assert qkv_proj.scheme.weight_quant is None + assert qkv_proj.scheme.input_quant is None + assert not qkv_proj.scheme.quantized + assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map + sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501 + assert sparsity_map.get("Linear").format == "sparse-24-bitmask" + assert sparsity_map.get("Linear").sparsity_structure == "2:4" + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output \ No newline at end of file From 5e8a118ded50e79c4efd789c23990c68d770fb36 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 22 Jan 2025 21:39:31 +0000 Subject: [PATCH 06/12] Lint Signed-off-by: Rahul Tuli --- .../compressed_tensors/schemes/compressed_tensors_24.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index e2106da0c9815..4f183ac585c30 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -266,7 +266,8 @@ def _decompress_bitmask_compressed_weight( :param compressed: 
The 2:4 sparse weight tensor compressed using the sparse-24-bitmask compressor. :param bitmask: The 2:4 bitmask associated with the compressed weights. - :param layer: The layer whose weights need to be processed after loading. + :param layer: The layer whose weights need to be processed + after loading. :return: The decompressed 2:4 sparse weight tensor. """ From d1806c5bef1cbb84054d33053bf01c8e3f329a3b Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 23 Jan 2025 14:48:49 +0000 Subject: [PATCH 07/12] Delete: compression params after decompression Signed-off-by: Rahul Tuli --- .../compressed_tensors/schemes/compressed_tensors_24.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 4f183ac585c30..db1b1b60556bf 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -164,6 +164,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: bitmask=layer.bitmask, layer=layer, ) + # compressed and bitmask tensors + # are no longer needed after decompression + + del layer.compressed + del layer.bitmask # torch.compile workaround if hasattr(layer, "input_scale"): From ce45cf98c00bf3be0d042d04191ec9332f3a020b Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 23 Jan 2025 19:06:02 +0000 Subject: [PATCH 08/12] Update: Compressed Tensors version Signed-off-by: Rahul Tuli --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index 97e33a6dbd880..cfa02025629f2 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -34,6 +34,6 @@ pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. 
-compressed-tensors == 0.9.0 # required for compressed-tensors +compressed-tensors == 0.9.1 # required for compressed-tensors depyf==0.18.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py From c0c99b12cb475b98076aca5d70621896e8a35f89 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Tue, 28 Jan 2025 13:06:40 -0500 Subject: [PATCH 09/12] Update vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py Co-authored-by: Tyler Michael Smith Signed-off-by: Rahul Tuli --- .../compressed_tensors/schemes/compressed_tensors_24.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index db1b1b60556bf..b4ed7bcf6f33e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -164,9 +164,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: bitmask=layer.bitmask, layer=layer, ) + # compressed and bitmask tensors # are no longer needed after decompression - del layer.compressed del layer.bitmask From cae85ea8eb4b0203226f9170eeb8db0b74bae295 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 30 Jan 2025 18:40:29 +0000 Subject: [PATCH 10/12] Address: Review comments Signed-off-by: Rahul Tuli --- .../schemes/compressed_tensors_24.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index b4ed7bcf6f33e..9c92488f095bb 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -263,16 +263,21 @@ def _decompress_bitmask_compressed_weight( self, compressed: torch.Tensor, bitmask: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor: """ - Decompress a compressed 2:4 sparse weight tensor - using the bitmask and return the result. - + Decompress a compressed 2:4 sparse weight tensor using the bitmask and + return the result. + This function also supports sharded decompression. - :param compressed: The 2:4 sparse weight tensor - compressed using the sparse-24-bitmask compressor. - :param bitmask: The 2:4 bitmask associated with the compressed weights. - :param layer: The layer whose weights need to be processed - after loading. + :param compressed: The 2:4 sparse weight tensor compressed using the + sparse-24-bitmask compressor. This is different from + `cutlass_sparse_compress` which uses a different scheme (2 bits for + every nonzero element that represent the coordinate within the block + of 4). The bitmask compression here uses a bitmask to indicate the + positions of non-zero elements. + :param bitmask: The 2:4 bitmask associated with the compressed weights, + representing the positions of non-zero elements in the compressed + tensor. + :param layer: The layer whose weights need to be processed after loading. :return: The decompressed 2:4 sparse weight tensor. 
""" From e91e98b3c431534c3aace55d47c8f42de53c3eea Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 30 Jan 2025 18:49:16 +0000 Subject: [PATCH 11/12] pre-commit hooks Signed-off-by: Rahul Tuli Signed-off-by: Rahul Tuli --- .../schemes/compressed_tensors_24.py | 174 +++++++++++------- 1 file changed, 104 insertions(+), 70 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 9c92488f095bb..0fb8dfa96a19c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -33,7 +33,6 @@ def __init__( input_quant: Optional[QuantizationArgs] = None, model_compression_config: Optional[Dict[str, Any]] = None, ): - self.quantized = quantized self.weight_quant = weight_quant self.input_quant = input_quant @@ -50,12 +49,16 @@ def get_min_capability(cls) -> int: # Only cutlass 3.x kernels are implemented so far return 90 - def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: List[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - + def create_weights( + self, + layer: torch.nn.Module, + input_size: int, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): if not sparse_cutlass_supported(): raise ValueError( "Sparse CUTLASS not supported. vLLM must be built with " @@ -68,37 +71,47 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype) # parameter to store uncompressed weight - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=self.weights_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=self.weights_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) if self.do_sparse_decompress: assert all(partition_size % 8 == 0 for partition_size in output_partition_sizes ), "All partitions must be divisible by 8 for " "2:4 sparse compressed models" - shape = BasevLLMParameter(data=torch.empty(2, 1, - dtype=torch.int64), - weight_loader=weight_loader) + shape = BasevLLMParameter( + data=torch.empty(2, 1, dtype=torch.int64), + weight_loader=weight_loader, + ) compressed_weight = ModelWeightParameter( - data=torch.empty(sum(output_partition_sizes), - input_size_per_partition // 2, - dtype=self.weights_dtype), + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=self.weights_dtype, + ), input_dim=1, output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) - bitmask = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // 8, - dtype=torch.uint8), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + bitmask = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 8, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("shape", shape) layer.register_parameter("compressed", compressed_weight) @@ -112,14 
+125,16 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) else: assert (self.weight_quant and self.weight_quant.strategy == QuantizationStrategy.TENSOR.value) weight_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) @@ -128,9 +143,10 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, # register input quant scale assert (self.input_quant.strategy == QuantizationStrategy.TENSOR.value) - input_scale = BasevLLMParameter(data=torch.empty( - 1, dtype=torch.float32), - weight_loader=weight_loader) + input_scale = BasevLLMParameter( + data=torch.empty(1, dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("input_scale", input_scale) @@ -151,12 +167,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: """ Compress weights after loading. Store compressed weight and meta tensor - + :post-condition: layer.w_compressed and layer.meta are set to the compressed weight and meta tensor in the format expected by the Cutlass kernels :param layer: The layer with the weights to be processed - + """ if self.do_sparse_decompress: layer.weight.data = self._decompress_bitmask_compressed_weight( @@ -177,10 +193,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.weight_quant: if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value: - layer.weight_scale = torch.nn.Parameter(convert_to_channelwise( - weight_scale=layer.weight_scale, - logical_widths=layer.logical_widths), - requires_grad=False) + layer.weight_scale = torch.nn.Parameter( + convert_to_channelwise( + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ), + requires_grad=False, + ) else: # torch.compile workaround layer.weight_scale = torch.nn.Parameter( @@ -190,20 +209,22 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False) layer.meta = torch.nn.Parameter(meta, requires_grad=False) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ - Returns the output tensor for the layer with 2:4 + Returns the output tensor for the layer with 2:4 sparse compressed weights, given the input tensor and bias - :param layer: The layer with 2:4 sparse compressed + :param layer: The layer with 2:4 sparse compressed weights to be used for the computation :param x: The input tensor to the layer :param bias: The bias to be added to the output tensor - :return: The output tensor of the layer + :return: The output tensor of the layer """ if self.quantized: scale = None @@ -227,13 +248,15 @@ def apply_weights(self, input_scale = layer.input_scale q_input = x - out = ops.cutlass_scaled_sparse_mm(a=q_input, - bt_nzs=layer.weight, - bt_meta=layer.meta, - scale_a=input_scale, - scale_b=layer.weight_scale, - out_dtype=self.output_dtype, - bias=bias) + out = ops.cutlass_scaled_sparse_mm( + a=q_input, + bt_nzs=layer.weight, + bt_meta=layer.meta, + scale_a=input_scale, + scale_b=layer.weight_scale, + 
out_dtype=self.output_dtype, + bias=bias, + ) assert out.is_contiguous() return out @@ -260,31 +283,38 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype: raise ValueError("Quantization type not supported by Cutlass") def _decompress_bitmask_compressed_weight( - self, compressed: torch.Tensor, bitmask: torch.Tensor, - layer: torch.nn.Module) -> torch.Tensor: + self, + compressed: torch.Tensor, + bitmask: torch.Tensor, + layer: torch.nn.Module, + ) -> torch.Tensor: """ - Decompress a compressed 2:4 sparse weight tensor using the bitmask and + Decompress a compressed 2:4 sparse weight tensor using the bitmask and return the result. This function also supports sharded decompression. - :param compressed: The 2:4 sparse weight tensor compressed using the - sparse-24-bitmask compressor. This is different from - `cutlass_sparse_compress` which uses a different scheme (2 bits for - every nonzero element that represent the coordinate within the block - of 4). The bitmask compression here uses a bitmask to indicate the + :param compressed: The 2:4 sparse weight tensor compressed using the + sparse-24-bitmask compressor. This is different from + `cutlass_sparse_compress` which uses a different scheme (2 bits for + every nonzero element that represent the coordinate within the block + of 4). The bitmask compression here uses a bitmask to indicate the positions of non-zero elements. - :param bitmask: The 2:4 bitmask associated with the compressed weights, - representing the positions of non-zero elements in the compressed + :param bitmask: The 2:4 bitmask associated with the compressed weights, + representing the positions of non-zero elements in the compressed tensor. - :param layer: The layer whose weights need to be processed after loading. + :param layer: The layer whose weights need to be processed after + loading. :return: The decompressed 2:4 sparse weight tensor. """ sparsity_compressor = self.model_compressor.sparsity_compressor - def _process_split(bitmask_compressed_weight: torch.Tensor, shape, - bitmask: torch.Tensor) -> torch.Tensor: + def _process_split( + bitmask_compressed_weight: torch.Tensor, + shape, + bitmask: torch.Tensor, + ) -> torch.Tensor: weight_data = dict( compressed=bitmask_compressed_weight, shape=shape, @@ -311,8 +341,12 @@ def _process_split(bitmask_compressed_weight: torch.Tensor, shape, decompressed = combine_shards(decompressed_shards) else: decompressed = sparsity_compressor.decompress_weight( - dict(compressed=compressed, - shape=(layer.logical_widths[0], - layer.input_size_per_partition), - bitmask=bitmask)) + dict( + compressed=compressed, + shape=( + layer.logical_widths[0], + layer.input_size_per_partition, + ), + bitmask=bitmask, + )) return decompressed From 4cace0c7c24c264128b28e390642bf3501e6c7b9 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 30 Jan 2025 19:20:42 +0000 Subject: [PATCH 12/12] test-file ruff Signed-off-by: Rahul Tuli --- tests/quantization/test_compressed_tensors.py | 303 ++++++++++++------ 1 file changed, 211 insertions(+), 92 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index db21e74102897..0655f2b385f3a 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -3,6 +3,7 @@ Run `pytest tests/quantization/test_compressed_tensors.py`. 
""" + from typing import Optional import pytest @@ -22,12 +23,30 @@ @pytest.mark.parametrize( "model_args", - [("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor", - QuantizationType.INT, 2560, True), - ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel", - QuantizationType.INT, 2560, True), - ("nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", "tensor", - QuantizationType.INT, 2560, False)]) + [ + ( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + "tensor", + QuantizationType.INT, + 2560, + True, + ), + ( + "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", + "channel", + QuantizationType.INT, + 2560, + True, + ), + ( + "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", + "tensor", + QuantizationType.INT, + 2560, + False, + ), + ], +) def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args): model_path, strategy, quant_type, shape_0, is_symmetric = model_args with vllm_runner(model_path, enforce_eager=True) as llm: @@ -85,21 +104,31 @@ def zp_valid(zp: Optional[torch.Tensor]): assert output -@pytest.mark.parametrize("model_path", [ - "neuralmagic/Llama-3.2-1B-quantized.w8a8", - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym", - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym", - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym" -]) +@pytest.mark.parametrize( + "model_path", + [ + "neuralmagic/Llama-3.2-1B-quantized.w8a8", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym", + ], +) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner, - example_prompts, model_path, - max_tokens, num_logprobs): +def test_compressed_tensors_w8a8_logprobs( + hf_runner, + vllm_runner, + example_prompts, + model_path, + max_tokens, + num_logprobs, +): dtype = "bfloat16" # skip language translation prompt for the static per tensor asym model - if model_path == "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym": # noqa: E501 + if (model_path == + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym" + ): # noqa: E501 example_prompts = example_prompts[0:-1] with hf_runner(model_path, dtype=dtype) as hf_model: @@ -125,13 +154,21 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner): assert output -@pytest.mark.parametrize("model_args", [ - ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"), - ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"), - ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"), - ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym", - "channel"), -]) +@pytest.mark.parametrize( + "model_args", + [ + ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"), + ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"), + ( + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", + "channel", + ), + ( + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym", + "channel", + ), + ], +) def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args): model_path, strategy = model_args with vllm_runner(model_path, dtype=torch.float16) as llm: @@ -156,9 +193,12 @@ def check_model(model): @pytest.mark.parametrize( "wNa16_args", - [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 
8), - ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8), - ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)]) + [ + ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8), + ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8), + ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4), + ], +) def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): model, strategy, group, pack_factor = wNa16_args with vllm_runner(model) as llm: @@ -218,7 +258,8 @@ def check_model(model): CompressedTensorsLinearMethod) assert isinstance( qkv_proj.scheme, - (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8)) + (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8), + ) assert qkv_proj.input_scale.dtype is torch.float32 @@ -241,8 +282,10 @@ def test_compressed_tensors_kv_cache(vllm_runner): assert output -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse FP8 is not yet supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse FP8 is not yet supported on this GPU type.", +) def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy, @@ -259,18 +302,35 @@ def _test_2of4_quant_models(qkv_proj, assert sparsity_map.get("Linear").sparsity_structure == "2:4" -@pytest.mark.skipif(not current_platform.has_device_capability(90), - reason="Sparse FP8 is not yet supported on this GPU type.") -@pytest.mark.parametrize("args_2of4", [ - ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel", - "token"), - ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing", - "channel", "tensor"), - ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", "tensor", - "tensor"), - ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing", - "tensor", "token"), -]) +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Sparse FP8 is not yet supported on this GPU type.", +) +@pytest.mark.parametrize( + "args_2of4", + [ + ( + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", + "channel", + "token", + ), + ( + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing", + "channel", + "tensor", + ), + ( + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", + "tensor", + "tensor", + ), + ( + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing", + "tensor", + "token", + ), + ], +) def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4): model, weight_strategy, input_strategy = args_2of4 with vllm_runner(model) as llm: @@ -289,18 +349,35 @@ def check_model(model): assert output -@pytest.mark.skipif(not current_platform.has_device_capability(90), - reason="Sparse FP8 is not yet supported on this GPU type.") -@pytest.mark.parametrize("args_2of4", [ - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM", - "channel", "token"), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM", - "channel", "tensor"), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM", - "tensor", "token"), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM", - "tensor", "tensor"), -]) +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Sparse FP8 is not yet supported on this GPU type.", +) +@pytest.mark.parametrize( + "args_2of4", + [ + ( + 
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM", + "channel", + "token", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM", + "channel", + "tensor", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM", + "tensor", + "token", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM", + "tensor", + "tensor", + ), + ], +) def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4): model, weight_strategy, input_strategy = args_2of4 with vllm_runner(model) as llm: @@ -310,10 +387,12 @@ def check_model(model): qkv_proj = layer.self_attn.qkv_proj assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn - _test_2of4_quant_models(qkv_proj, - weight_strategy, - input_strategy, - format="sparse-24-bitmask") + _test_2of4_quant_models( + qkv_proj, + weight_strategy, + input_strategy, + format="sparse-24-bitmask", + ) llm.apply_model(check_model) @@ -322,18 +401,35 @@ def check_model(model): assert output -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="cutlass is not yet supported on this GPU type.") -@pytest.mark.parametrize("args_2of4", [ - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM", - "channel", "token"), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM", - "channel", "tensor"), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM", - "tensor", "token"), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM", - "tensor", "tensor"), -]) +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="cutlass is not yet supported on this GPU type.", +) +@pytest.mark.parametrize( + "args_2of4", + [ + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM", + "channel", + "token", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM", + "channel", + "tensor", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM", + "tensor", + "token", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM", + "tensor", + "tensor", + ), + ], +) def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4): model, weight_strategy, input_strategy = args_2of4 with vllm_runner(model) as llm: @@ -343,10 +439,12 @@ def check_model(model): qkv_proj = layer.self_attn.qkv_proj assert qkv_proj.scheme.weights_dtype == torch.int8 - _test_2of4_quant_models(qkv_proj, - weight_strategy, - input_strategy, - format="sparse-24-bitmask") + _test_2of4_quant_models( + qkv_proj, + weight_strategy, + input_strategy, + format="sparse-24-bitmask", + ) llm.apply_model(check_model) @@ -355,16 +453,30 @@ def check_model(model): assert output -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse FP8 is not yet supported on this GPU type.") -@pytest.mark.parametrize("args_2of4", [ - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing", - "channel", "token"), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", "tensor", - "tensor"), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing", - "tensor", "token"), -]) +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse FP8 is not yet 
supported on this GPU type.", +) +@pytest.mark.parametrize( + "args_2of4", + [ + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing", + "channel", + "token", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", + "tensor", + "tensor", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing", + "tensor", + "token", + ), + ], +) def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4): model, weight_strategy, input_strategy = args_2of4 with vllm_runner(model) as llm: @@ -386,10 +498,12 @@ def check_model(model): @pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.") @pytest.mark.skipif( not sparse_cutlass_supported(), - reason="2of4 Sparse is not yet supported on this GPU type.") + reason="2of4 Sparse is not yet supported on this GPU type.", +) @pytest.mark.parametrize( "args_2of4", - [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")]) + [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")], +) def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4): model = args_2of4 with vllm_runner(model) as llm: @@ -406,7 +520,9 @@ def check_model(model): assert qkv_proj.scheme.input_quant is None assert not qkv_proj.scheme.quantized assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map - sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501 + sparsity_map = ( + qkv_proj.quant_method.quantization_config.sparsity_scheme_map + ) # noqa: E501 assert sparsity_map.get("Linear").format == "dense" assert sparsity_map.get("Linear").sparsity_structure == "2:4" @@ -417,11 +533,12 @@ def check_model(model): assert output -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Cutlass is not yet supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Cutlass is not yet supported on this GPU type.", +) @pytest.mark.parametrize( - "args_2of4", - [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")]) + "args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")]) def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4): model = args_2of4 with vllm_runner(model) as llm: @@ -438,7 +555,9 @@ def check_model(model): assert qkv_proj.scheme.input_quant is None assert not qkv_proj.scheme.quantized assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map - sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501 + sparsity_map = ( + qkv_proj.quant_method.quantization_config.sparsity_scheme_map + ) # noqa: E501 assert sparsity_map.get("Linear").format == "sparse-24-bitmask" assert sparsity_map.get("Linear").sparsity_structure == "2:4" @@ -446,4 +565,4 @@ def check_model(model): output = llm.generate_greedy("Hello my name is", max_tokens=20) print(output) - assert output \ No newline at end of file + assert output
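
For context on the sparse-24-bitmask format that these patches decompress through `sparsity_compressor.decompress_weight`, the sketch below shows the underlying idea in plain PyTorch. A 2:4-sparse weight of shape (rows, cols) is stored as the packed nonzero values of shape (rows, cols // 2) plus a uint8 bitmask of shape (rows, cols // 8) marking the nonzero positions. The LSB-first bit order, the row-major packing of the nonzero values, and the helper names are assumptions made purely for illustration; the actual layout and decompression are defined by the compressed-tensors library, which the patches delegate to.

import torch


def decompress_24_bitmask_sketch(compressed: torch.Tensor,
                                 bitmask: torch.Tensor,
                                 shape: tuple) -> torch.Tensor:
    # Expand each uint8 of the bitmask into 8 boolean flags
    # (LSB-first bit order is assumed here purely for illustration).
    rows, cols = shape
    bit_values = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128],
                              dtype=torch.uint8,
                              device=bitmask.device)
    mask = (bitmask.unsqueeze(-1) & bit_values).ne(0).reshape(rows, -1)[:, :cols]

    # 2:4 sparsity means exactly half the positions per row are nonzero,
    # so the packed values fill the masked positions exactly once.
    assert int(mask.sum()) == compressed.numel()

    # Scatter the packed nonzero values (assumed row-major) back into a
    # dense tensor of the original shape.
    dense = torch.zeros(rows, cols,
                        dtype=compressed.dtype,
                        device=compressed.device)
    dense[mask] = compressed.reshape(-1)
    return dense


def decompress_sharded_sketch(compressed: torch.Tensor,
                              bitmask: torch.Tensor,
                              logical_widths: list,
                              input_size_per_partition: int) -> torch.Tensor:
    # Mirrors the sharded path in _decompress_bitmask_compressed_weight:
    # split along the output dimension, decompress each shard, recombine.
    shards = [
        decompress_24_bitmask_sketch(c, b, (w, input_size_per_partition))
        for c, b, w in zip(torch.split(compressed, logical_widths),
                           torch.split(bitmask, logical_widths),
                           logical_widths)
    ]
    return torch.cat(shards, dim=0)

For fused layers such as QKV, the patches split `compressed` and `bitmask` along the output dimension by `logical_widths`, decompress each shard, and recombine them with `combine_shards`; the second helper mirrors that flow under the assumption that recombination amounts to concatenation along the output dimension.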