Marlin24 Compressor (vllm-project#77)
* add marlin compressor

* full implementation running

* save meta matrix

* fix type

* change axis we pack on

* debug

* fix offset

* fix scale shape

* cast dtype

* cleanup

* fix

* avoid extra transpose

* add uncompressed weights

* unit tests

* cleanup

* Update src/compressed_tensors/compressors/marlin_24.py

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>

* Update src/compressed_tensors/compressors/marlin_24.py

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>

* style

* move helper fn

---------

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Sara Adkins and dbogunowicz authored Jun 11, 2024
1 parent a34618b commit 14b1db1
Showing 10 changed files with 833 additions and 6 deletions.
3 changes: 2 additions & 1 deletion in src/compressed_tensors/compressors/__init__.py
@@ -18,6 +18,7 @@
from .dense import DenseCompressor
from .helpers import load_compressed, save_compressed, save_compressed_model
from .int_quantized import IntQuantizationCompressor
-from .model_compressor import ModelCompressor
+from .marlin_24 import Marlin24Compressor
+from .model_compressor import ModelCompressor, map_modules_to_quant_args
from .pack_quantized import PackedQuantizationCompressor
from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
250 changes: 250 additions & 0 deletions in src/compressed_tensors/compressors/marlin_24.py
@@ -0,0 +1,250 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple

import numpy as np
import torch
from compressed_tensors.compressors import Compressor
from compressed_tensors.compressors.utils import (
    get_permutations_24,
    sparse_semi_structured_from_dense_cutlass,
    tensor_follows_mask_structure,
)
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import quantize
from compressed_tensors.utils import is_quantization_param, merge_names
from torch import Tensor
from tqdm import tqdm


_LOGGER: logging.Logger = logging.getLogger(__name__)


@Compressor.register(name=CompressionFormat.marlin_24.value)
class Marlin24Compressor(Compressor):
"""
Compresses a quantized model with 2:4 sparsity structure for inference with the
Marlin24 kernel. Decompression is not implemented for this compressor.
"""

COMPRESSION_PARAM_NAMES = ["weight_packed", "scale_packed", "meta"]

    @staticmethod
    def validate_quant_compatability(
        model_quant_args: Dict[str, QuantizationArgs]
    ) -> bool:
        """
        Checks if every quantized module in the model is compatible with Marlin24
        compression. Quantization must use a channel or group strategy with a
        group_size of 128. Only symmetric quantization is supported.

        :param model_quant_args: dictionary mapping module names to their
            quantization configuration
        :return: True if all modules are compatible with Marlin24 compression,
            raises a ValueError otherwise
        """
        for name, quant_args in model_quant_args.items():
            strategy = quant_args.strategy
            group_size = quant_args.group_size
            symmetric = quant_args.symmetric
            if (
                strategy is not QuantizationStrategy.GROUP
                and strategy is not QuantizationStrategy.CHANNEL
            ):
                raise ValueError(
                    f"Marlin24 Compressor is only valid for group and channel "
                    f"quantization strategies, got {strategy} in {name}"
                )

            if group_size is not None and group_size != 128:
                raise ValueError(
                    f"Marlin24 Compressor is only valid for group size 128, "
                    f"got {group_size} in {name}"
                )

            if not symmetric:
                raise ValueError(
                    f"Marlin24 Compressor is only valid for symmetric quantization, "
                    f"got symmetric={symmetric} in {name}"
                )

        return True
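
As an aside (not part of this diff), a minimal sketch of a configuration that passes this check, assuming the QuantizationArgs constructor accepts the field names used above; the module name is a placeholder:

from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

# 4-bit, group-size-128, symmetric quantization: the combination the validator
# accepts (group or channel strategy, group_size 128, symmetric only)
args = QuantizationArgs(
    num_bits=4,
    strategy=QuantizationStrategy.GROUP,
    group_size=128,
    symmetric=True,
)

# keys are module name prefixes; "model.layers.0.mlp.down_proj" is hypothetical
Marlin24Compressor.validate_quant_compatability(
    {"model.layers.0.mlp.down_proj": args}
)  # returns True, or raises a ValueError for unsupported configs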

    @staticmethod
    def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
        """
        Checks if a tensor fits the required 2:4 sparsity structure

        :param name: name of the tensor to check
        :param weight: tensor to check for sparsity structure
        :return: True if all rows match the 2:4 sparsity structure, raises
            ValueError otherwise
        """
        if not tensor_follows_mask_structure(weight):
            raise ValueError(
                "Marlin24 Compressor is only compatible with weights that have "
                f"a 2:4 sparsity structure. Found segments in {name} "
                "that do not match the expected structure."
            )

        return True

    def compress(
        self,
        model_state: Dict[str, Tensor],
        model_quant_args: Dict[str, QuantizationArgs],
        **kwargs,
    ) -> Dict[str, Tensor]:
        """
        Compresses a quantized state_dict with 2:4 sparsity structure for inference
        with the Marlin24 kernel

        :param model_state: state dict of the uncompressed model
        :param model_quant_args: quantization args for each quantized weight, needed
            by the quantize function to calculate bit depth
        :return: compressed state dict
        """
        self.validate_quant_compatability(model_quant_args)

        compressed_dict = {}
        weight_suffix = ".weight"
        _LOGGER.debug(
            f"Compressing model with {len(model_state)} parameterized layers..."
        )

        for name, value in tqdm(model_state.items(), desc="Compressing model"):
            if name.endswith(weight_suffix):
                prefix = name[: -(len(weight_suffix))]
                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
                if scale is not None:  # weight is quantized, compress it

                    # Marlin24 kernel requires float16 inputs
                    scale = scale.to(torch.float16)
                    value = value.to(torch.float16)

                    # quantize weight, keeping it as a float16 for now
                    quant_args = model_quant_args[prefix]
                    value = quantize(
                        x=value, scale=scale, zero_point=zp, args=quant_args
                    )

                    # compress based on sparsity structure
                    self.validate_sparsity_structure(prefix, value)
                    value, meta = compress_weight_24(value)
                    meta = meta.cpu()

                    # Marlin24 kernel expects input dim first
                    value = value.t().contiguous().cpu()
                    scale = scale.t().contiguous().cpu()
                    og_weight_shape = value.shape

                    # Marlin24 kernel expects unsigned values, shift zero-point
                    value += (1 << quant_args.num_bits) // 2

                    # pack quantized weight and scale
                    value = pack_weight_24(value, quant_args)
                    packed_scale = pack_scales_24(scale, quant_args, og_weight_shape)
                    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

                    # save compressed values
                    compressed_dict[merge_names(prefix, "scale_packed")] = packed_scale
                    compressed_dict[merge_names(prefix, "weight_packed")] = value
                    compressed_dict[merge_names(prefix, "meta")] = meta
                    continue

            if not is_quantization_param(name):
                # export unquantized parameters without modification
                compressed_dict[name] = value.to("cpu")

        return compressed_dict

    def decompress(
        self, path_to_model_or_tensors: str, device: str = "cpu"
    ) -> Generator[Tuple[str, Tensor], None, None]:
        raise NotImplementedError(
            "Decompression is not implemented for the Marlin24 Compressor."
        )
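
For context (not part of this change), a rough end-to-end sketch of driving the compressor directly. The layer name, tensor shapes, the crude 2:4 mask, and the per-channel fp16 scale layout below are illustrative assumptions chosen to satisfy the divisibility requirements of the packing helpers, not values taken from this diff:

import torch
from compressed_tensors.compressors import Marlin24Compressor
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

# hypothetical single-layer state dict: an fp16 weight with 2:4 sparsity plus
# its per-channel scale and zero point, as an upstream quantization flow would produce
weight = torch.randn(128, 256, dtype=torch.float16)
weight = weight.reshape(-1, 4)
weight[:, :2] = 0.0  # zero two of every four values, purely for illustration
weight = weight.reshape(128, 256)

state_dict = {
    "layer.weight": weight,
    "layer.weight_scale": torch.rand(128, 1, dtype=torch.float16) + 0.01,
    "layer.weight_zero_point": torch.zeros(128, 1, dtype=torch.float16),
}
quant_args = {
    "layer": QuantizationArgs(
        num_bits=4, strategy=QuantizationStrategy.CHANNEL, symmetric=True
    )
}

compressed = Marlin24Compressor().compress(state_dict, model_quant_args=quant_args)
# expected keys: "layer.weight_packed", "layer.scale_packed", "layer.meta"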


def compress_weight_24(weight: Tensor):
    weight = weight.contiguous()
    w_comp, meta = sparse_semi_structured_from_dense_cutlass(weight)
    w_comp = w_comp.contiguous()
    return w_comp, meta


def marlin_permute_weights(q_w, size_k, size_n, perm, tile):
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
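
As an illustration (not from this diff), the tiling step alone on a toy tensor, using an identity permutation as a stand-in for the interleaved permutation that get_permutations_24 actually returns:

import torch

q_w = torch.arange(32 * 32).reshape(32, 32)  # toy 32x32 "quantized" weight
identity_perm = torch.arange(1024)           # placeholder; the real perm is interleaved

tiled = marlin_permute_weights(q_w, size_k=32, size_n=32, perm=identity_perm, tile=16)
print(tiled.shape)  # torch.Size([2, 512]); each output row holds one strip of 16x16 tiles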


def pack_weight_24(
    weight: Tensor,
    quantization_args: QuantizationArgs,
    tile: int = 16,
):
    size_k = weight.shape[0]
    size_n = weight.shape[1]
    num_bits = quantization_args.num_bits
    pack_factor = 32 // num_bits

    # Reshuffle to marlin_24 format
    perm, _, _ = get_permutations_24(num_bits)
    q_w = marlin_permute_weights(weight, size_k, size_n, perm, tile)

    q_w = q_w.cpu().numpy().astype(np.uint32)

    q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
    for i in range(pack_factor):
        q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(np.int32))

    return q_packed
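
A small aside on the packing arithmetic (not part of the diff): with num_bits=4 the pack_factor is 8, so eight 4-bit values share one 32-bit word, with value j occupying bits 4*j through 4*j+3. A toy numpy example mirroring the loop above:

import numpy as np

vals = np.array([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=np.uint32)  # already shifted to unsigned
packed = np.zeros((1, 1), dtype=np.uint32)
for i in range(8):
    packed |= vals[:, i::8] << 4 * i
print(hex(int(packed[0, 0])))  # 0x87654321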


def pack_scales_24(scales, quantization_args, w_shape):
    size_k = w_shape[0]
    size_n = w_shape[1]
    num_bits = quantization_args.num_bits

    _, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits)

    if (
        quantization_args.strategy is QuantizationStrategy.GROUP
        and quantization_args.group_size < size_k
    ):
        scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4]
    else:  # channelwise
        scales = scales.reshape((-1, len(scale_perm_single_2_4)))[
            :, scale_perm_single_2_4
        ]
    scales = scales.reshape((-1, size_n)).contiguous()

    return scales
6 changes: 3 additions & 3 deletions in src/compressed_tensors/compressors/model_compressor.py
@@ -45,7 +45,7 @@
from transformers.file_utils import CONFIG_NAME


__all__ = ["ModelCompressor"]
__all__ = ["ModelCompressor", "map_modules_to_quant_args"]

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -190,7 +190,7 @@ def compress(
state_dict = model.state_dict()

compressed_state_dict = state_dict
-quantized_modules_to_args = _get_weight_arg_mappings(model)
+quantized_modules_to_args = map_modules_to_quant_args(model)
if self.quantization_compressor is not None:
compressed_state_dict = self.quantization_compressor.compress(
state_dict, model_quant_args=quantized_modules_to_args
@@ -269,7 +269,7 @@ def _replace_weights(self, dense_weight_generator, model):
data_old.data = data_new.data


-def _get_weight_arg_mappings(model: Module) -> Dict:
+def map_modules_to_quant_args(model: Module) -> Dict:
quantized_modules_to_args = {}
for name, submodule in iter_named_leaf_modules(model):
if is_module_quantized(submodule):
19 changes: 19 additions & 0 deletions in src/compressed_tensors/compressors/utils/__init__.py
@@ -0,0 +1,19 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .helpers import *
from .permutations_24 import *
from .semi_structured_conversions import *
43 changes: 43 additions & 0 deletions in src/compressed_tensors/compressors/utils/helpers.py
@@ -0,0 +1,43 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


__all__ = ["tensor_follows_mask_structure"]


def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
    """
    :param tensor: tensor to check
    :param mask: mask structure to check for, in the format "n:m"
    :return: True if the tensor follows the mask structure, raises a ValueError
        otherwise. Note that some weights can incidentally be zero, so we check
        for at least n zeros in each chunk of size m
    """
    n, m = tuple(map(int, mask.split(":")))

    # Reshape the tensor into chunks of size m
    tensor = tensor.view(-1, m)

    # Count the number of zeros in each chunk
    zero_counts = (tensor == 0).sum(dim=1)

    # Check that each chunk contains at least n zeros; a greater-than-or-equal
    # comparison is needed because some weights can incidentally be zero
    if not torch.all(zero_counts >= n).item():
        raise ValueError(f"Tensor does not match the expected {mask} sparsity structure")

    return True
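
For illustration (not part of the diff), the helper applied to two toy tensors, one that satisfies the 2:4 pattern and one that does not:

import torch

ok = torch.tensor([[0.0, 1.5, 0.0, -2.0, 3.0, 0.0, 0.0, 0.25]])
assert tensor_follows_mask_structure(ok)  # two zeros in every chunk of four -> True

bad = torch.tensor([[1.0, 1.5, 0.0, -2.0, 3.0, 0.0, 0.0, 0.25]])
try:
    tensor_follows_mask_structure(bad)    # only one zero in the first chunk
except ValueError:
    print("tensor is not 2:4 structured")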