Marlin24 Compressor #77

Merged
merged 22 commits on Jun 11, 2024
3 changes: 2 additions & 1 deletion src/compressed_tensors/compressors/__init__.py
@@ -18,6 +18,7 @@
from .dense import DenseCompressor
from .helpers import load_compressed, save_compressed, save_compressed_model
from .int_quantized import IntQuantizationCompressor
from .model_compressor import ModelCompressor
from .marlin_24 import Marlin24Compressor
from .model_compressor import ModelCompressor, map_modules_to_quant_args
from .pack_quantized import PackedQuantizationCompressor
from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
267 changes: 267 additions & 0 deletions src/compressed_tensors/compressors/marlin_24.py
@@ -0,0 +1,267 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
from typing import Dict, Generator, Tuple

import numpy as np
import torch
from compressed_tensors.compressors import Compressor
from compressed_tensors.compressors.utils import (
get_permutations_24,
sparse_semi_structured_from_dense_cutlass,
)
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import quantize
from compressed_tensors.utils import is_quantization_param, merge_names
from torch import Tensor
from tqdm import tqdm


_LOGGER: logging.Logger = logging.getLogger(__name__)


@Compressor.register(name=CompressionFormat.marlin_24.value)
class Marlin24Compressor(Compressor):
"""
Compresses a quantized model with 2:4 sparsity structure for inference with the
Marlin24 kernel. Decompression is not implemented for this compressor.
"""

COMPRESSION_PARAM_NAMES = ["weight_packed", "scale_packed", "meta"]

@staticmethod
def validate_quant_compatability(model_quant_args: Dict[str, QuantizationArgs]):
"""
Checks if every quantized module in the model is compatible with Marlin24
compression. Quantization must use the channel or group strategy with a
group_size of 128. Only symmetric quantization is supported.

:param model_quant_args: dictionary mapping module names to their
quantization configuration
:return: True if all modules are compatible with Marlin24 compression, raises
a ValueError otherwise
"""
for name, quant_args in model_quant_args.items():
strategy = quant_args.strategy
group_size = quant_args.group_size
symmetric = quant_args.symmetric
if (
strategy is not QuantizationStrategy.GROUP
and strategy is not QuantizationStrategy.CHANNEL
):
raise ValueError(
f"Marlin24 Compressor is only valid for group and channel "
f"quantization strategies, got {strategy} in {name}"
)

if group_size is not None and group_size != 128:
raise ValueError(
f"Marlin24 Compressor is only valid for group size 128, "
f"got {group_size} in {name}"
)

if not symmetric:
raise ValueError(
f"Marlin24 Compressor is only valid for symmetric quantzation, "
f"got symmetric={symmetric} in {name}"
)

return True
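
Note (illustration, not part of this diff): a minimal sketch of args that pass and fail this check, using only the fields referenced above; the module name "layer" is hypothetical.

from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

# symmetric group quantization with group_size 128 is accepted
good = QuantizationArgs(
    num_bits=4, symmetric=True, strategy=QuantizationStrategy.GROUP, group_size=128
)
Marlin24Compressor.validate_quant_compatability({"layer": good})  # returns True

# asymmetric quantization is rejected
bad = QuantizationArgs(
    num_bits=4, symmetric=False, strategy=QuantizationStrategy.GROUP, group_size=128
)
# Marlin24Compressor.validate_quant_compatability({"layer": bad})  # raises ValueError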

@staticmethod
def validate_sparsity_structure(
name: str, weight: Tensor, num_rows_to_sample: int = 20
) -> bool:
"""
Checks if a tensor fits the 2:4 sparsity structure by sampling a specified
number of rows.

:param name: name of the tensor to check
:param weight: tensor to check for sparsity structure
:param num_rows_to_sample: number of rows to check the sparsity structure of
:return: True if all sampled rows match the 2:4 sparsity structure, raises
ValueError otherwise
"""
BLOCK_SIZE = 4
MAX_NON_ZEROS = 2

weight = weight.contiguous()

num_rows, num_cols = weight.shape
sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)

non_24_segments = 0
for i in sampled_row_idxs:
for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
block = weight[i, j : j + BLOCK_SIZE]
num_nonzero = torch.count_nonzero(block)
if num_nonzero > MAX_NON_ZEROS:
non_24_segments += 1

if non_24_segments > 0:
raise ValueError(
"Marlin24 Compressor is only compatible with weights that have "
f"a 2:4 sparsity structure. Found {non_24_segments} segments in {name} "
"that do not match the expected structure."
)

return True
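
Note (illustration, not part of this diff): the check samples rows and counts non-zeros per 4-element block, so a toy tensor makes the behavior easy to see; the layer names are made up.

import torch

# at most 2 non-zeros in each sampled 4-element block -> passes
ok = torch.tensor([[1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 4.0]])
Marlin24Compressor.validate_sparsity_structure("ok_layer", ok)  # returns True

# a fully dense row has 4 non-zeros per block -> raises ValueError
dense = torch.ones(2, 8)
# Marlin24Compressor.validate_sparsity_structure("dense_layer", dense)  # raises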

def compress(
self,
model_state: Dict[str, Tensor],
model_quant_args: Dict[str, QuantizationArgs],
**kwargs,
) -> Dict[str, Tensor]:
"""
Compresses a quantized state_dict with 2:4 sparsity structure for inference
with the Marlin24 kernel

:param model_state: state dict of uncompressed model
:param model_quant_args: quantization args for each quantized weight, needed
by the quantize function to calculate bit depth
:return: compressed state dict
"""
self.validate_quant_compatability(model_quant_args)

compressed_dict = {}
weight_suffix = ".weight"
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)

for name, value in tqdm(model_state.items(), desc="Compressing model"):
if name.endswith(weight_suffix):
prefix = name[: -(len(weight_suffix))]
Contributor:
Suggested change
prefix = name[: -(len(weight_suffix))]
prefix = name.replace(weight_suffix, "")

More readable :)

Author:
Yes, but the replace approach runs the risk of replacing ".weight" if it exists elsewhere in the string; not likely, but possible.
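
Note (illustration, not part of this diff): for a hypothetical parameter name that contains ".weight" twice, the two approaches differ.

name = "model.weight_proj.weight"
name.replace(".weight", "")   # -> "model_proj" (strips both occurrences)
name[: -len(".weight")]       # -> "model.weight_proj" (strips only the suffix)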

scale = model_state.get(merge_names(prefix, "weight_scale"), None)
zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
if scale is not None: # weight is quantized, compress it

# Marlin24 kernel requires float16 inputs
scale = scale.to(torch.float16)
value = value.to(torch.float16)

# quantize weight, keeping it as a float16 for now
quant_args = model_quant_args[prefix]
value = quantize(
x=value, scale=scale, zero_point=zp, args=quant_args
)

# compress based on sparsity structure
self.validate_sparsity_structure(prefix, value)
value, meta = compress_weight_24(value)
meta = meta.cpu()

# Marlin24 kernel expects input dim first
value = value.t().contiguous().cpu()
scale = scale.t().contiguous().cpu()
og_weight_shape = value.shape

# Marlin24 kernel expects unsigned values, shift zero-point
value += (1 << quant_args.num_bits) // 2

# pack quantized weight and scale
value = pack_weight_24(value, quant_args)
packed_scale = pack_scales_24(scale, quant_args, og_weight_shape)
meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

# save compressed values
compressed_dict[merge_names(prefix, "scale_packed")] = packed_scale
compressed_dict[merge_names(prefix, "weight_packed")] = value
compressed_dict[merge_names(prefix, "meta")] = meta
continue

if not is_quantization_param(name):
# export unquantized parameters without modifying
compressed_dict[name] = value.to("cpu")

return compressed_dict

def decompress(
self, path_to_model_or_tensors: str, device: str = "cpu"
) -> Generator[Tuple[str, Tensor], None, None]:
raise NotImplementedError(
"Decompression is not implemented for the Marlin24 Compressor."
)
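
Note (rough usage sketch, not part of this diff): compressing a single hypothetical linear layer whose weight is already 2:4 sparse; the module name, shapes, and values are made up.

import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

weight = torch.randn(128, 256)
weight.view(-1, 4)[:, 2:] = 0.0  # zero 2 of every 4 values -> 2:4 sparsity

state_dict = {
    "layer.weight": weight,
    "layer.weight_scale": torch.rand(128, 2),        # 256 cols / group_size 128 = 2 groups
    "layer.weight_zero_point": torch.zeros(128, 2),  # symmetric -> all zeros
}
quant_args = {
    "layer": QuantizationArgs(
        num_bits=4, symmetric=True,
        strategy=QuantizationStrategy.GROUP, group_size=128,
    )
}

compressed = Marlin24Compressor().compress(state_dict, model_quant_args=quant_args)
# expected keys: "layer.weight_packed", "layer.scale_packed", "layer.meta"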


def compress_weight_24(weight: Tensor):
weight = weight.contiguous()
w_comp, meta = sparse_semi_structured_from_dense_cutlass(weight)
w_comp = w_comp.contiguous()
return w_comp, meta


def marlin_permute_weights(q_w, size_k, size_n, perm, tile):
assert q_w.shape == (size_k, size_n)
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // tile, size_n * tile))

q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

return q_w


def pack_weight_24(
weight: Tensor,
quantization_args: QuantizationArgs,
tile: int = 16,
):
size_k = weight.shape[0]
size_n = weight.shape[1]
num_bits = quantization_args.num_bits
pack_factor = 32 // num_bits

# Reshuffle to marlin_24 format
perm, _, _ = get_permutations_24(num_bits)
q_w = marlin_permute_weights(weight, size_k, size_n, perm, tile)

q_w = q_w.cpu().numpy().astype(np.uint32)

q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
q_packed |= q_w[:, i::pack_factor] << num_bits * i

q_packed = torch.from_numpy(q_packed.astype(np.int32))

return q_packed
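
Note (illustration, not part of this diff): the packing loop interleaves pack_factor columns into each 32-bit word. The same bit arithmetic on one row of eight 4-bit values:

import numpy as np

num_bits = 4
pack_factor = 32 // num_bits                       # 8 values per uint32
q_w = np.arange(8, dtype=np.uint32).reshape(1, 8)  # values 0..7

q_packed = np.zeros((1, 1), dtype=np.uint32)
for i in range(pack_factor):
    q_packed |= q_w[:, i::pack_factor] << num_bits * i

print(hex(int(q_packed[0, 0])))  # 0x76543210: value j lands in bits [4*j, 4*j + 4)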


def pack_scales_24(scales, quantization_args, w_shape):
size_k = w_shape[0]
size_n = w_shape[1]
num_bits = quantization_args.num_bits

_, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits)

if (
quantization_args.strategy is QuantizationStrategy.GROUP
and quantization_args.group_size < size_k
):
scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4]
else: # channelwise
scales = scales.reshape((-1, len(scale_perm_single_2_4)))[
:, scale_perm_single_2_4
]
scales = scales.reshape((-1, size_n)).contiguous()

return scales
6 changes: 3 additions & 3 deletions src/compressed_tensors/compressors/model_compressor.py
@@ -45,7 +45,7 @@
from transformers.file_utils import CONFIG_NAME


__all__ = ["ModelCompressor"]
__all__ = ["ModelCompressor", "map_modules_to_quant_args"]

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -190,7 +190,7 @@ def compress(
state_dict = model.state_dict()

compressed_state_dict = state_dict
quantized_modules_to_args = _get_weight_arg_mappings(model)
quantized_modules_to_args = map_modules_to_quant_args(model)
if self.quantization_compressor is not None:
compressed_state_dict = self.quantization_compressor.compress(
state_dict, model_quant_args=quantized_modules_to_args
@@ -269,7 +269,7 @@ def _replace_weights(self, dense_weight_generator, model):
data_old.data = data_new.data


def _get_weight_arg_mappings(model: Module) -> Dict:
def map_modules_to_quant_args(model: Module) -> Dict:
quantized_modules_to_args = {}
for name, submodule in iter_named_leaf_modules(model):
if is_module_quantized(submodule):
18 changes: 18 additions & 0 deletions src/compressed_tensors/compressors/utils/__init__.py
@@ -0,0 +1,18 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .permutations_24 import *
from .semi_structured_conversions import *
65 changes: 65 additions & 0 deletions src/compressed_tensors/compressors/utils/permutations_24.py
@@ -0,0 +1,65 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy
import torch


__all__ = ["get_permutations_24"]


# Precompute permutations for Marlin24 weight and scale shuffling
# Originally implemented in nm-vllm/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py # noqa: E501
#
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight
# data so that it is compatible with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data
# as it is needed for tensor-core (without the need to use ldmatrix instructions)
def get_permutations_24(num_bits):
perm_list = []
for i in range(32):
perm1 = []
col = i // 4
col_o = col // 2
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
for j in range(4):
perm_list.extend([p + 1 * j for p in perm1])
perm = numpy.array(perm_list)

if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
perm = torch.from_numpy(perm)
scale_perm = []
for i in range(8):
scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
scale_perm_single = []
for i in range(8):
scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
return perm, scale_perm, scale_perm_single
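
Note (sanity check, not part of this diff): the loop structure above fixes the sizes of the returned permutations.

perm, scale_perm, scale_perm_single = get_permutations_24(num_bits=4)
assert perm.numel() == 1024           # 32 iterations x 8 positions x 4 repeats
assert len(scale_perm) == 64          # 8 x 8 grouped-scale indices
assert len(scale_perm_single) == 64   # 8 x 8 per-channel scale indices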