Skip to content

Commit

Permalink
Compressed lifecycle implementation (INT8 only) (#33)
Browse files Browse the repository at this point in the history
* Compressed lifecycle implementation (INT8 only)

* Apply suggestions from code review

* small fixes for runtime

* Quantization Compressor Support (#45)

* add classes

* WIP

* moving around classes

* code complete

* tests passing

* unit test bugs

* fill out int decompression

* docstrings

* allow repeat frozens

* int compressor unit tests

* PR comments

* fix device issue

* fixing leaf checker

* initial commit

* Revert "Merge branch 'main' into compressed-lifecycle"

This reverts commit 8dcdde5, reversing
changes made to bb36936.

* update version

* fix test

---------

Co-authored-by: Sara Adkins <sara@neuralmagic.com>
Co-authored-by: dbogunowicz <damian@neuralmagic.com>
  • Loading branch information
3 people authored May 7, 2024
1 parent f7e928b commit 964276d
Show file tree
Hide file tree
Showing 24 changed files with 737 additions and 125 deletions.
3 changes: 2 additions & 1 deletion src/compressed_tensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
# limitations under the License.

SPARSITY_CONFIG_NAME = "sparsity_config"
QUANTIZATION_CONFIG_NAME = "sparseml_quantization_config"
QUANTIZATION_CONFIG_NAME = "quantization_config"
COMPRESSION_CONFIG_NAME = "compression_config"
4 changes: 3 additions & 1 deletion src/compressed_tensors/compressors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

# flake8: noqa

from .base import ModelCompressor
from .base import Compressor
from .dense import DenseCompressor
from .helpers import load_compressed, save_compressed, save_compressed_model
from .int_quantized import IntQuantizationCompressor
from .model_compressor import ModelCompressor
from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
64 changes: 10 additions & 54 deletions src/compressed_tensors/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,56 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
from typing import Dict, Generator, Optional, Tuple
from typing import Dict, Generator, Tuple, Union

from compressed_tensors.base import SPARSITY_CONFIG_NAME
from compressed_tensors.config import CompressionConfig
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationConfig
from compressed_tensors.registry import RegistryMixin
from compressed_tensors.utils import get_safetensors_folder
from torch import Tensor
from torch.nn import Module, Parameter
from tqdm import tqdm
from transformers import AutoConfig


__all__ = ["ModelCompressor"]
__all__ = ["Compressor"]


class ModelCompressor(RegistryMixin):
class Compressor(RegistryMixin):
"""
Base class representing a model compression algorithm.
Base class representing a model compression algorithm
:param config: config specifying compression parameters
"""

@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: str
) -> Optional["ModelCompressor"]:
"""
Given a path to a model config, extract a sparsity config if it exists and
return the associated ModelCompressor
:param pretrained_model_name_or_path: path to model config on disk or HF hub
:return: matching compressor if config contains a sparsity config
"""
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
if sparsity_config is None:
return None

format = sparsity_config.get("format")
sparsity_config = CompressionConfig.load_from_registry(
format, **sparsity_config
)
compressor = cls.load_from_registry(format, config=sparsity_config)
return compressor

def __init__(self, config: Optional[CompressionConfig] = None):
def __init__(
self, config: Union[SparsityCompressionConfig, QuantizationConfig, None] = None
):
self.config = config

def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
"""
Compresses a dense state dict
Expand All @@ -83,21 +57,3 @@ def decompress(
:return: compressed state dict
"""
raise NotImplementedError()

def overwrite_weights(self, model_path: str, model: Module):
"""
Overwrites the weights in model with weights decompressed from model_path
:param model_path: path to compressed weights
:param model: pytorch model to load decompressed weights into
"""
model_path = get_safetensors_folder(model_path)
dense_gen = self.decompress(model_path)
for name, data in tqdm(dense_gen, desc="Decompressing model"):
# loading the decompressed weights into the model
model_device = operator.attrgetter(name)(model).device
data_new = Parameter(data.to(model_device))
data_old = operator.attrgetter(name)(model)
data_old.data = data_new.data

setattr(model, SPARSITY_CONFIG_NAME, self.config)
8 changes: 4 additions & 4 deletions src/compressed_tensors/compressors/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@

from typing import Dict, Generator, Tuple

from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.compressors import Compressor
from compressed_tensors.config import CompressionFormat
from torch import Tensor


@ModelCompressor.register(name=CompressionFormat.dense_sparsity.value)
class DenseCompressor(ModelCompressor):
@Compressor.register(name=CompressionFormat.dense.value)
class DenseCompressor(Compressor):
"""
Identity compressor for dense models, returns the original state_dict
"""

def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
return model_state

def decompress(
Expand Down
24 changes: 12 additions & 12 deletions src/compressed_tensors/compressors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from typing import Dict, Generator, Optional, Tuple, Union

import torch
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import CompressionConfig, CompressionFormat
from compressed_tensors.compressors import Compressor
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.utils.safetensors_load import get_weight_mappings
from safetensors import safe_open
from safetensors.torch import save_file
Expand Down Expand Up @@ -48,28 +48,28 @@ def save_compressed(
if tensors is None or len(tensors) == 0:
raise ValueError("No tensors or empty tensors provided to compress")

# if no compression_format specified, default to `dense_sparsity`
compression_format = compression_format or CompressionFormat.dense_sparsity.value
# if no compression_format specified, default to `dense`
compression_format = compression_format or CompressionFormat.dense.value

if not (
compression_format in ModelCompressor.registered_names()
or compression_format in ModelCompressor.registered_aliases()
compression_format in Compressor.registered_names()
or compression_format in Compressor.registered_aliases()
):
raise ValueError(
f"Unknown compression format: {compression_format}. "
f"Must be one of {set(ModelCompressor.registered_names() + ModelCompressor.registered_aliases())}" # noqa E501
f"Must be one of {set(Compressor.registered_names() + Compressor.registered_aliases())}" # noqa E501
)

# compress
compressor = ModelCompressor.load_from_registry(compression_format)
compressor = Compressor.load_from_registry(compression_format)
# save compressed tensors
compressed_tensors = compressor.compress(tensors)
save_file(compressed_tensors, save_path)


def load_compressed(
compressed_tensors: Union[str, Path],
compression_config: CompressionConfig = None,
compression_config: SparsityCompressionConfig = None,
device: Optional[str] = "cpu",
) -> Generator[Tuple[str, Tensor], None, None]:
"""
Expand All @@ -90,9 +90,9 @@ def load_compressed(

if (
compression_config is None
or compression_config.format == CompressionFormat.dense_sparsity.value
or compression_config.format == CompressionFormat.dense.value
):
# if no compression_config specified, or `dense_sparsity` format specified,
# if no compression_config specified, or `dense` format specified,
# assume tensors are not compressed on disk
weight_mappings = get_weight_mappings(compressed_tensors)
for weight_name, file_with_weight_name in weight_mappings.items():
Expand All @@ -102,7 +102,7 @@ def load_compressed(
else:
# decompress tensors
compression_format = compression_config.format
compressor = ModelCompressor.load_from_registry(
compressor = Compressor.load_from_registry(
compression_format, config=compression_config
)
yield from compressor.decompress(compressed_tensors, device=device)
Expand Down
95 changes: 95 additions & 0 deletions src/compressed_tensors/compressors/int_quantized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple

import torch
from compressed_tensors.compressors import Compressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from safetensors import safe_open
from torch import Tensor
from tqdm import tqdm


__all__ = ["IntQuantizationCompressor"]

_LOGGER: logging.Logger = logging.getLogger(__name__)


@Compressor.register(name=CompressionFormat.int_quantized.value)
class IntQuantizationCompressor(Compressor):
"""
Integer compression for quantized models. Weight of each quantized layer is
converted from its original float type to the format specified by the layer's
quantization scheme.
"""

COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]

def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
model_quant_args = kwargs["model_quant_args"]
compressed_dict = {}
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)

for name, value in tqdm(model_state.items(), desc="Compressing model"):
if name.endswith(".weight"):
prefix = name.removesuffix(".weight")
scale = model_state.get(merge_names(prefix, "weight_scale"), None)
zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
if scale is not None and zp is not None:
# weight is quantized, compress it
quant_args = model_quant_args[prefix]
try:
bit_depth = torch.finfo(value.dtype).bits
except TypeError:
bit_depth = torch.iinfo(value.dtype).bits
if bit_depth > quant_args.num_bits:
# only quantize if not already quantized
value = quantize(
x=value,
scale=scale,
zero_point=zp,
args=quant_args,
dtype=torch.int8,
)

compressed_dict[name] = value.to("cpu")

return compressed_dict

def decompress(
self, path_to_model_or_tensors: str, device: str = "cpu"
) -> Generator[Tuple[str, Tensor], None, None]:
weight_mappings = get_nested_weight_mappings(
path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
)
for weight_name in weight_mappings.keys():
weight_data = {}
for param_name, safe_path in weight_mappings[weight_name].items():
full_name = merge_names(weight_name, param_name)
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)

if len(weight_data) == len(self.COMPRESSION_PARAM_NAMES):
decompressed = dequantize(
x_q=weight_data["weight"],
scale=weight_data["weight_scale"],
zero_point=weight_data["weight_zero_point"],
)
yield merge_names(weight_name, "weight"), decompressed
Loading

0 comments on commit 964276d

Please sign in to comment.