diff --git a/src/compressed_tensors/compressors/__init__.py b/src/compressed_tensors/compressors/__init__.py index b6f2c7d6..d3bb61f5 100644 --- a/src/compressed_tensors/compressors/__init__.py +++ b/src/compressed_tensors/compressors/__init__.py @@ -18,6 +18,7 @@ from .dense import DenseCompressor from .helpers import load_compressed, save_compressed, save_compressed_model from .int_quantized import IntQuantizationCompressor -from .model_compressor import ModelCompressor +from .marlin_24 import Marlin24Compressor +from .model_compressor import ModelCompressor, map_modules_to_quant_args from .pack_quantized import PackedQuantizationCompressor from .sparse_bitmask import BitmaskCompressor, BitmaskTensor diff --git a/src/compressed_tensors/compressors/marlin_24.py b/src/compressed_tensors/compressors/marlin_24.py new file mode 100644 index 00000000..1e34594a --- /dev/null +++ b/src/compressed_tensors/compressors/marlin_24.py @@ -0,0 +1,250 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Dict, Generator, Tuple + +import numpy as np +import torch +from compressed_tensors.compressors import Compressor +from compressed_tensors.compressors.utils import ( + get_permutations_24, + sparse_semi_structured_from_dense_cutlass, + tensor_follows_mask_structure, +) +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy +from compressed_tensors.quantization.lifecycle.forward import quantize +from compressed_tensors.utils import is_quantization_param, merge_names +from torch import Tensor +from tqdm import tqdm + + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +@Compressor.register(name=CompressionFormat.marlin_24.value) +class Marlin24Compressor(Compressor): + """ + Compresses a quantized model with 2:4 sparsity structure for inference with the + Marlin24 kernel. Decompression is not implemented for this compressor. + """ + + COMPRESSION_PARAM_NAMES = ["weight_packed", "scale_packed", "meta"] + + @staticmethod + def validate_quant_compatability( + model_quant_args: Dict[str, QuantizationArgs] + ) -> bool: + """ + Checks if every quantized module in the model is compatible with Marlin24 + compression. Quantization must be channel or group strategy with group_size + of 128. 
Only symmetric quantization is supported
+
+        :param model_quant_args: dictionary mapping module names to their
+            quantization configuration
+        :return: True if all modules are compatible with Marlin24 compression, raises
+            a ValueError otherwise
+        """
+        for name, quant_args in model_quant_args.items():
+            strategy = quant_args.strategy
+            group_size = quant_args.group_size
+            symmetric = quant_args.symmetric
+            if (
+                strategy is not QuantizationStrategy.GROUP
+                and strategy is not QuantizationStrategy.CHANNEL
+            ):
+                raise ValueError(
+                    f"Marlin24 Compressor is only valid for group and channel "
+                    f"quantization strategies, got {strategy} in {name}"
+                )
+
+            if group_size is not None and group_size != 128:
+                raise ValueError(
+                    f"Marlin24 Compressor is only valid for group size 128, "
+                    f"got {group_size} in {name}"
+                )
+
+            if not symmetric:
+                raise ValueError(
+                    f"Marlin24 Compressor is only valid for symmetric quantization, "
+                    f"got symmetric={symmetric} in {name}"
+                )
+
+        return True
+
+    @staticmethod
+    def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
+        """
+        Checks if a tensor fits the required 2:4 sparsity structure
+
+        :param name: name of the tensor to check
+        :param weight: tensor to check for sparsity structure
+        :return: True if all rows match the 2:4 sparsity structure, raises
+            ValueError otherwise
+        """
+
+        if not tensor_follows_mask_structure(weight):
+            raise ValueError(
+                "Marlin24 Compressor is only compatible with weights that have "
+                f"a 2:4 sparsity structure. Found segments in {name} "
+                "that do not match the expected structure."
+            )
+
+        return True
+
+    def compress(
+        self,
+        model_state: Dict[str, Tensor],
+        model_quant_args: Dict[str, QuantizationArgs],
+        **kwargs,
+    ) -> Dict[str, Tensor]:
+        """
+        Compresses a quantized state_dict with 2:4 sparsity structure for inference
+        with the Marlin24 kernel
+
+        :param model_state: state dict of uncompressed model
+        :param model_quant_args: quantization args for each quantized weight, needed
+            for the quantize function to calculate bit depth
+        :return: compressed state dict
+        """
+        self.validate_quant_compatability(model_quant_args)
+
+        compressed_dict = {}
+        weight_suffix = ".weight"
+        _LOGGER.debug(
+            f"Compressing model with {len(model_state)} parameterized layers..."
+ ) + + for name, value in tqdm(model_state.items(), desc="Compressing model"): + if name.endswith(weight_suffix): + prefix = name[: -(len(weight_suffix))] + scale = model_state.get(merge_names(prefix, "weight_scale"), None) + zp = model_state.get(merge_names(prefix, "weight_zero_point"), None) + if scale is not None: # weight is quantized, compress it + + # Marlin24 kernel requires float16 inputs + scale = scale.to(torch.float16) + value = value.to(torch.float16) + + # quantize weight, keeping it as a float16 for now + quant_args = model_quant_args[prefix] + value = quantize( + x=value, scale=scale, zero_point=zp, args=quant_args + ) + + # compress based on sparsity structure + self.validate_sparsity_structure(prefix, value) + value, meta = compress_weight_24(value) + meta = meta.cpu() + + # Marlin24 kernel expects input dim first + value = value.t().contiguous().cpu() + scale = scale.t().contiguous().cpu() + og_weight_shape = value.shape + + # Marlin24 kernel expects unsigned values, shift zero-point + value += (1 << quant_args.num_bits) // 2 + + # pack quantized weight and scale + value = pack_weight_24(value, quant_args) + packed_scale = pack_scales_24(scale, quant_args, og_weight_shape) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + # save compressed values + compressed_dict[merge_names(prefix, "scale_packed")] = packed_scale + compressed_dict[merge_names(prefix, "weight_packed")] = value + compressed_dict[merge_names(prefix, "meta")] = meta + continue + + if not is_quantization_param(name): + # export unquantized parameters without modifying + compressed_dict[name] = value.to("cpu") + + return compressed_dict + + def decompress( + self, path_to_model_or_tensors: str, device: str = "cpu" + ) -> Generator[Tuple[str, Tensor], None, None]: + raise NotImplementedError( + "Decompression is not implemented for the Marlin24 Compressor." 
+ ) + + +def compress_weight_24(weight: Tensor): + weight = weight.contiguous() + w_comp, meta = sparse_semi_structured_from_dense_cutlass(weight) + w_comp = w_comp.contiguous() + return w_comp, meta + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def pack_weight_24( + weight: Tensor, + quantization_args: QuantizationArgs, + tile: int = 16, +): + size_k = weight.shape[0] + size_n = weight.shape[1] + num_bits = quantization_args.num_bits + pack_factor = 32 // num_bits + + # Reshuffle to marlin_24 format + perm, _, _ = get_permutations_24(num_bits) + q_w = marlin_permute_weights(weight, size_k, size_n, perm, tile) + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)) + + return q_packed + + +def pack_scales_24(scales, quantization_args, w_shape): + size_k = w_shape[0] + size_n = w_shape[1] + num_bits = quantization_args.num_bits + + _, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits) + + if ( + quantization_args.strategy is QuantizationStrategy.GROUP + and quantization_args.group_size < size_k + ): + scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4] + else: # channelwise + scales = scales.reshape((-1, len(scale_perm_single_2_4)))[ + :, scale_perm_single_2_4 + ] + scales = scales.reshape((-1, size_n)).contiguous() + + return scales diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py index 64ad14d8..b1df83dc 100644 --- a/src/compressed_tensors/compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressor.py @@ -45,7 +45,7 @@ from transformers.file_utils import CONFIG_NAME -__all__ = ["ModelCompressor"] +__all__ = ["ModelCompressor", "map_modules_to_quant_args"] _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -190,7 +190,7 @@ def compress( state_dict = model.state_dict() compressed_state_dict = state_dict - quantized_modules_to_args = _get_weight_arg_mappings(model) + quantized_modules_to_args = map_modules_to_quant_args(model) if self.quantization_compressor is not None: compressed_state_dict = self.quantization_compressor.compress( state_dict, model_quant_args=quantized_modules_to_args @@ -269,7 +269,7 @@ def _replace_weights(self, dense_weight_generator, model): data_old.data = data_new.data -def _get_weight_arg_mappings(model: Module) -> Dict: +def map_modules_to_quant_args(model: Module) -> Dict: quantized_modules_to_args = {} for name, submodule in iter_named_leaf_modules(model): if is_module_quantized(submodule): diff --git a/src/compressed_tensors/compressors/utils/__init__.py b/src/compressed_tensors/compressors/utils/__init__.py new file mode 100644 index 00000000..e78a9969 --- /dev/null +++ b/src/compressed_tensors/compressors/utils/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+
+from .helpers import *
+from .permutations_24 import *
+from .semi_structured_conversions import *
diff --git a/src/compressed_tensors/compressors/utils/helpers.py b/src/compressed_tensors/compressors/utils/helpers.py
new file mode 100644
index 00000000..2a23ae6a
--- /dev/null
+++ b/src/compressed_tensors/compressors/utils/helpers.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+__all__ = ["tensor_follows_mask_structure"]
+
+
+def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
+    """
+    :param tensor: tensor to check
+    :param mask: mask structure to check for, in the format "n:m"
+    :return: True if the tensor follows the mask structure, raises a
+        ValueError otherwise. Note, some weights can incidentally be zero,
+        so we check for at least n zeros in each chunk of size m
+    """
+
+    n, m = tuple(map(int, mask.split(":")))
+    # Reshape the tensor into chunks of size m
+    tensor = tensor.view(-1, m)
+
+    # Count the number of zeros in each chunk
+    zero_counts = (tensor == 0).sum(dim=1)
+
+    # Check that the number of zeros in each chunk is at least n; >= is
+    # used rather than == because some weights can incidentally
+    # be zero
+    if not torch.all(zero_counts >= n).item():
+        raise ValueError(f"Tensor does not match the {mask} mask structure")
+
+    return True
diff --git a/src/compressed_tensors/compressors/utils/permutations_24.py b/src/compressed_tensors/compressors/utils/permutations_24.py
new file mode 100644
index 00000000..5b078e27
--- /dev/null
+++ b/src/compressed_tensors/compressors/utils/permutations_24.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ + +import numpy +import torch + + +__all__ = ["get_permutations_24"] + + +# Precompute permutations for Marlin24 weight and scale shuffling +# Originally implemented in nm-vllm/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py # noqa: E501 +# +# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight +# data so that it is compatible with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data +# as it is needed for tensor-core (without the need to use ldmatrix instructions) +def get_permutations_24(num_bits): + perm_list = [] + for i in range(32): + perm1 = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return perm, scale_perm, scale_perm_single diff --git a/src/compressed_tensors/compressors/utils/semi_structured_conversions.py b/src/compressed_tensors/compressors/utils/semi_structured_conversions.py new file mode 100644 index 00000000..79afce19 --- /dev/null +++ b/src/compressed_tensors/compressors/utils/semi_structured_conversions.py @@ -0,0 +1,341 @@ +# +# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es). +# Pulled from nm-vllm/vllm/model_executor/layers/quantization/utils/format_24.py +# +# flake8: noqa +# isort: skip_file + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +__all__ = [ + "sparse_semi_structured_from_dense_cutlass", + "sparse_semi_structured_to_dense_cutlass", + "mask_creator", +] + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. 
The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. +def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two 
bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. + + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1) + ) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( + m, k // 2 + ) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols,) + ) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. 
+def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. 
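+    # Unpack the packed metadata words into individual 2-bit indices before
+    # computing gather offsets: each int16 meta element holds 8 indices and
+    # each int32 element holds 16 (2 * quadbits_per_meta_elem per element).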
+ meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index e2b0a97e..b9ecab8e 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -27,6 +27,7 @@ class CompressionFormat(Enum): sparse_bitmask = "sparse-bitmask" int_quantized = "int-quantized" pack_quantized = "pack-quantized" + marlin_24 = "marlin-24" class SparsityCompressionConfig(RegistryMixin, BaseModel): diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py index ee8e6ddd..9cdac782 100644 --- a/src/compressed_tensors/utils/safetensors_load.py +++ b/src/compressed_tensors/utils/safetensors_load.py @@ -31,6 +31,7 @@ "get_weight_mappings", "get_nested_weight_mappings", "get_quantization_state_dict", + "is_quantization_param", ] @@ -214,7 +215,7 @@ def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]: weight_mappings = get_weight_mappings(model_path) state_dict = {} for weight_name, safe_path in weight_mappings.items(): - if not _is_quantization_weight(weight_name): + if not is_quantization_param(weight_name): continue with safe_open(safe_path, framework="pt", device="cpu") as f: state_dict[weight_name] = f.get_tensor(weight_name) @@ -222,7 +223,7 @@ def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]: return state_dict -def _is_quantization_weight(name: str) -> bool: +def is_quantization_param(name: str) -> bool: """ Checks is a parameter name is associated with a quantization parameter diff --git a/tests/test_compressors/test_marlin_24.py b/tests/test_compressors/test_marlin_24.py new file mode 100644 index 00000000..0e89b9d8 --- /dev/null +++ b/tests/test_compressors/test_marlin_24.py @@ -0,0 +1,106 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
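+
+# Tests for the Marlin24Compressor: registry lookup, validation of quantization
+# args and 2:4 sparsity structure, and end-to-end compression of a small
+# quantized, 2:4-sparse model.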
+ +from collections import OrderedDict + +import pytest +import torch +from compressed_tensors.compressors import ( + Compressor, + Marlin24Compressor, + map_modules_to_quant_args, +) +from compressed_tensors.compressors.utils import mask_creator +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationConfig, + QuantizationScheme, + QuantizationStatus, + QuantizationStrategy, + apply_quantization_config, + apply_quantization_status, +) +from compressed_tensors.utils import merge_names +from torch.nn.modules import Linear, Sequential + + +def get_2_4_quant_config(num_bits, strategy, ignore): + gs = 128 if strategy is QuantizationStrategy.GROUP else None + weights = QuantizationArgs(num_bits=num_bits, strategy=strategy, group_size=gs) + scheme = QuantizationScheme(weights=weights, targets=["Linear"]) + config = QuantizationConfig(config_groups={"group_0": scheme}, ignore=ignore) + return config + + +def test_marlin_registered(): + config_name = CompressionFormat.marlin_24.value + compressor = Compressor.load_from_registry(config_name) + assert isinstance(compressor, Marlin24Compressor) + + +@pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize( + "strategy", [QuantizationStrategy.GROUP, QuantizationStrategy.CHANNEL] +) +@pytest.mark.parametrize("layer_shape", [(512, 128), (1024, 1024), (128, 256)]) +def test_marlin24_format(num_bits, strategy, layer_shape): + QUANT_NAME = "quant" + NOT_QUANT_NAME = "not_quant" + model = Sequential( + OrderedDict( + [ + (QUANT_NAME, Linear(layer_shape[0], layer_shape[1], bias=False)), + (NOT_QUANT_NAME, Linear(layer_shape[1], 64, bias=False)), + ] + ) + ) + config = get_2_4_quant_config(num_bits, strategy, ignore=[NOT_QUANT_NAME]) + mask = mask_creator(model.quant.weight.data).bool() + model.quant.weight.data *= mask + + apply_quantization_config(model, config) + apply_quantization_status(model, QuantizationStatus.CALIBRATION) + + # runs observer to get scale and zero point + input = torch.rand((64, layer_shape[0])) + _ = model(input) + + state_dict = model.state_dict() + assert len(state_dict) == 4 + assert f"{NOT_QUANT_NAME}.weight_scale" not in state_dict + assert f"{QUANT_NAME}.weight_scale" in state_dict + + model_to_quant_args = map_modules_to_quant_args(model) + compressor = Marlin24Compressor() + compressor.validate_quant_compatability(model_to_quant_args) + compressor.validate_sparsity_structure( + QUANT_NAME, state_dict[f"{QUANT_NAME}.weight"] + ) + with pytest.raises(ValueError): + compressor.validate_sparsity_structure( + NOT_QUANT_NAME, state_dict[f"{NOT_QUANT_NAME}.weight"] + ) + + compressor = Marlin24Compressor() + compressed_state_dict = compressor.compress(state_dict, model_to_quant_args) + + assert len(compressed_state_dict) == 4 + assert torch.equal( + state_dict[f"{NOT_QUANT_NAME}.weight"], + compressed_state_dict[f"{NOT_QUANT_NAME}.weight"], + ) + for param_name in compressor.COMPRESSION_PARAM_NAMES: + full_param_name = merge_names(QUANT_NAME, param_name) + assert full_param_name in compressed_state_dict
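
Usage sketch (illustrative only, not part of the patch): the flow below mirrors test_marlin24_format above, assuming a single 2:4-sparse Linear layer with 4-bit, group-128, symmetric weight quantization; module names and shapes are taken from the test.

    from collections import OrderedDict

    import torch
    from torch.nn import Linear, Sequential

    from compressed_tensors.compressors import (
        Marlin24Compressor,
        map_modules_to_quant_args,
    )
    from compressed_tensors.compressors.utils import mask_creator
    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationConfig,
        QuantizationScheme,
        QuantizationStatus,
        QuantizationStrategy,
        apply_quantization_config,
        apply_quantization_status,
    )

    # Toy model with a single quantizable Linear layer (weight shape 128 x 512)
    model = Sequential(OrderedDict([("quant", Linear(512, 128, bias=False))]))

    # Enforce a 2:4 sparsity pattern on the weight before quantization
    model.quant.weight.data *= mask_creator(model.quant.weight.data).bool()

    # Attach 4-bit, group-128, symmetric weight quantization and calibrate observers
    weights = QuantizationArgs(
        num_bits=4, strategy=QuantizationStrategy.GROUP, group_size=128
    )
    scheme = QuantizationScheme(weights=weights, targets=["Linear"])
    config = QuantizationConfig(config_groups={"group_0": scheme})
    apply_quantization_config(model, config)
    apply_quantization_status(model, QuantizationStatus.CALIBRATION)
    _ = model(torch.rand((64, 512)))  # forward pass populates scale / zero-point

    # Compress the state dict into the Marlin24 packed format
    compressor = Marlin24Compressor()
    quant_args = map_modules_to_quant_args(model)
    compressed = compressor.compress(model.state_dict(), model_quant_args=quant_args)
    # compressed now contains quant.weight_packed, quant.scale_packed and quant.meta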