Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compressed lifecycle implementation (INT8 only) #33

Merged
merged 17 commits into from
May 7, 2024
Merged
3 changes: 2 additions & 1 deletion src/compressed_tensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
# limitations under the License.

SPARSITY_CONFIG_NAME = "sparsity_config"
QUANTIZATION_CONFIG_NAME = "sparseml_quantization_config"
QUANTIZATION_CONFIG_NAME = "quantization_config"
COMPRESSION_CONFIG_NAME = "compression_config"
4 changes: 3 additions & 1 deletion src/compressed_tensors/compressors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

# flake8: noqa

from .base import ModelCompressor
from .base import Compressor
from .dense import DenseCompressor
from .helpers import load_compressed, save_compressed, save_compressed_model
from .int_quantized import IntQuantizationCompressor
from .model_compressor import ModelCompressor
from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
64 changes: 10 additions & 54 deletions src/compressed_tensors/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,56 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
from typing import Dict, Generator, Optional, Tuple
from typing import Dict, Generator, Tuple, Union

from compressed_tensors.base import SPARSITY_CONFIG_NAME
from compressed_tensors.config import CompressionConfig
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationConfig
from compressed_tensors.registry import RegistryMixin
from compressed_tensors.utils import get_safetensors_folder
from torch import Tensor
from torch.nn import Module, Parameter
from tqdm import tqdm
from transformers import AutoConfig


__all__ = ["ModelCompressor"]
__all__ = ["Compressor"]


class ModelCompressor(RegistryMixin):
class Compressor(RegistryMixin):
"""
Base class representing a model compression algorithm.
Base class representing a model compression algorithm

:param config: config specifying compression parameters
"""

@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: str
) -> Optional["ModelCompressor"]:
"""
Given a path to a model config, extract a sparsity config if it exists and
return the associated ModelCompressor

:param pretrained_model_name_or_path: path to model config on disk or HF hub
:return: matching compressor if config contains a sparsity config
"""
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
if sparsity_config is None:
return None

format = sparsity_config.get("format")
sparsity_config = CompressionConfig.load_from_registry(
format, **sparsity_config
)
compressor = cls.load_from_registry(format, config=sparsity_config)
return compressor

def __init__(self, config: Optional[CompressionConfig] = None):
def __init__(
self, config: Union[SparsityCompressionConfig, QuantizationConfig, None] = None
):
self.config = config

def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
"""
Compresses a dense state dict

Expand All @@ -83,21 +57,3 @@ def decompress(
:return: compressed state dict
"""
raise NotImplementedError()

def overwrite_weights(self, model_path: str, model: Module):
"""
Overwrites the weights in model with weights decompressed from model_path

:param model_path: path to compressed weights
:param model: pytorch model to load decompressed weights into
"""
model_path = get_safetensors_folder(model_path)
dense_gen = self.decompress(model_path)
for name, data in tqdm(dense_gen, desc="Decompressing model"):
# loading the decompressed weights into the model
model_device = operator.attrgetter(name)(model).device
data_new = Parameter(data.to(model_device))
data_old = operator.attrgetter(name)(model)
data_old.data = data_new.data

setattr(model, SPARSITY_CONFIG_NAME, self.config)
8 changes: 4 additions & 4 deletions src/compressed_tensors/compressors/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@

from typing import Dict, Generator, Tuple

from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.compressors import Compressor
from compressed_tensors.config import CompressionFormat
from torch import Tensor


@ModelCompressor.register(name=CompressionFormat.dense_sparsity.value)
class DenseCompressor(ModelCompressor):
@Compressor.register(name=CompressionFormat.dense.value)
class DenseCompressor(Compressor):
"""
Identity compressor for dense models, returns the original state_dict
"""

def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
return model_state

def decompress(
Expand Down
24 changes: 12 additions & 12 deletions src/compressed_tensors/compressors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from typing import Dict, Generator, Optional, Tuple, Union

import torch
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import CompressionConfig, CompressionFormat
from compressed_tensors.compressors import Compressor
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.utils.safetensors_load import get_weight_mappings
from safetensors import safe_open
from safetensors.torch import save_file
Expand Down Expand Up @@ -48,28 +48,28 @@ def save_compressed(
if tensors is None or len(tensors) == 0:
raise ValueError("No tensors or empty tensors provided to compress")

# if no compression_format specified, default to `dense_sparsity`
compression_format = compression_format or CompressionFormat.dense_sparsity.value
# if no compression_format specified, default to `dense`
compression_format = compression_format or CompressionFormat.dense.value

if not (
compression_format in ModelCompressor.registered_names()
or compression_format in ModelCompressor.registered_aliases()
compression_format in Compressor.registered_names()
or compression_format in Compressor.registered_aliases()
):
raise ValueError(
f"Unknown compression format: {compression_format}. "
f"Must be one of {set(ModelCompressor.registered_names() + ModelCompressor.registered_aliases())}" # noqa E501
f"Must be one of {set(Compressor.registered_names() + Compressor.registered_aliases())}" # noqa E501
)

# compress
compressor = ModelCompressor.load_from_registry(compression_format)
compressor = Compressor.load_from_registry(compression_format)
# save compressed tensors
compressed_tensors = compressor.compress(tensors)
save_file(compressed_tensors, save_path)


def load_compressed(
compressed_tensors: Union[str, Path],
compression_config: CompressionConfig = None,
compression_config: SparsityCompressionConfig = None,
device: Optional[str] = "cpu",
) -> Generator[Tuple[str, Tensor], None, None]:
"""
Expand All @@ -90,9 +90,9 @@ def load_compressed(

if (
compression_config is None
or compression_config.format == CompressionFormat.dense_sparsity.value
or compression_config.format == CompressionFormat.dense.value
):
# if no compression_config specified, or `dense_sparsity` format specified,
# if no compression_config specified, or `dense` format specified,
# assume tensors are not compressed on disk
weight_mappings = get_weight_mappings(compressed_tensors)
for weight_name, file_with_weight_name in weight_mappings.items():
Expand All @@ -102,7 +102,7 @@ def load_compressed(
else:
# decompress tensors
compression_format = compression_config.format
compressor = ModelCompressor.load_from_registry(
compressor = Compressor.load_from_registry(
compression_format, config=compression_config
)
yield from compressor.decompress(compressed_tensors, device=device)
Expand Down
95 changes: 95 additions & 0 deletions src/compressed_tensors/compressors/int_quantized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple

import torch
from compressed_tensors.compressors import Compressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from safetensors import safe_open
from torch import Tensor
from tqdm import tqdm


__all__ = ["IntQuantizationCompressor"]

_LOGGER: logging.Logger = logging.getLogger(__name__)


@Compressor.register(name=CompressionFormat.int_quantized.value)
class IntQuantizationCompressor(Compressor):
"""
Integer compression for quantized models. Weight of each quantized layer is
converted from its original float type to the format specified by the layer's
quantization scheme.
"""

COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]

def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
model_quant_args = kwargs["model_quant_args"]
compressed_dict = {}
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)

for name, value in tqdm(model_state.items(), desc="Compressing model"):
if name.endswith(".weight"):
prefix = name.removesuffix(".weight")
scale = model_state.get(merge_names(prefix, "weight_scale"), None)
zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
if scale is not None and zp is not None:
# weight is quantized, compress it
quant_args = model_quant_args[prefix]
try:
bit_depth = torch.finfo(value.dtype).bits
except TypeError:
bit_depth = torch.iinfo(value.dtype).bits
if bit_depth > quant_args.num_bits:
# only quantize if not already quantized
value = quantize(
x=value,
scale=scale,
zero_point=zp,
args=quant_args,
dtype=torch.int8,
)

compressed_dict[name] = value.to("cpu")

return compressed_dict

def decompress(
self, path_to_model_or_tensors: str, device: str = "cpu"
) -> Generator[Tuple[str, Tensor], None, None]:
weight_mappings = get_nested_weight_mappings(
path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
)
for weight_name in weight_mappings.keys():
weight_data = {}
for param_name, safe_path in weight_mappings[weight_name].items():
full_name = merge_names(weight_name, param_name)
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)

if len(weight_data) == len(self.COMPRESSION_PARAM_NAMES):
decompressed = dequantize(
x_q=weight_data["weight"],
scale=weight_data["weight_scale"],
zero_point=weight_data["weight_zero_point"],
)
yield merge_names(weight_name, "weight"), decompressed
Loading
Loading