Update folder structure
Move tests
rahul-tuli committed Oct 2, 2024
1 parent 27011f6 commit a460b70
Showing 25 changed files with 401 additions and 29 deletions.
18 changes: 6 additions & 12 deletions src/compressed_tensors/compressors/__init__.py
@@ -14,15 +14,9 @@

# flake8: noqa

-from .base import BaseCompressor
-from .dense import DenseCompressor
-from .helpers import load_compressed, save_compressed, save_compressed_model
-from .marlin_24 import Marlin24Compressor
-from .model_compressor import ModelCompressor, map_modules_to_quant_args
-from .naive_quantized import (
-    FloatQuantizationCompressor,
-    IntQuantizationCompressor,
-    QuantizationCompressor,
-)
-from .pack_quantized import PackedQuantizationCompressor
-from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
+from .base import *
+from .helpers import *
+from .model_compressors import *
+from .quantized_compressors import *
+from .sparse_compressors import *
+from .sparse_quantized_compressors import *
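The wildcard re-exports above are what keep the public import surface stable even though the concrete modules moved into subpackages. A minimal downstream sketch, assuming each new subpackage re-exports its public names the way the __init__.py files in this commit do:

# Hypothetical downstream usage: these names now arrive via the subpackage
# wildcard re-exports rather than direct module imports.
from compressed_tensors.compressors import BaseCompressor, ModelCompressor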
6 changes: 3 additions & 3 deletions src/compressed_tensors/compressors/base.py
@@ -37,18 +37,18 @@ class BaseCompressor(RegistryMixin, ABC):
    Model Load Lifecycle (run_compressed=False):
        - ModelCompressor.decompress()
            - apply_quantization_config()
-           - Compressor.decompress()
+           - BaseCompressor.decompress()
    Model Save Lifecycle:
        - ModelCompressor.compress()
-           - Compressor.compress()
+           - BaseCompressor.compress()
    Module Lifecycle (run_compressed=True):
        - apply_quantization_config()
        - compressed_module = CompressedLinear(module)
            - initialize_module_for_quantization()
-           - Compressor.compression_param_info()
+           - BaseCompressor.compression_param_info()
            - register_parameters()
        - compressed_module.forward()
            - compressed_module.decompress()
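BaseCompressor inherits RegistryMixin, which backs the register/load_from_registry calls seen throughout this commit. A hedged sketch of the lookup side; passing config=None is an illustrative assumption, real callers pass the matching compression config as ModelCompressor.__init__ does below:

from compressed_tensors.compressors import BaseCompressor
from compressed_tensors.config import CompressionFormat

# Resolve a concrete compressor class by its registered format name.
# The registry is populated by the @BaseCompressor.register(...) decorators.
compressor = BaseCompressor.load_from_registry(
    CompressionFormat.sparse_bitmask.value, config=None
)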
2 changes: 1 addition & 1 deletion src/compressed_tensors/compressors/helpers.py
@@ -16,7 +16,7 @@
from typing import Dict, Generator, Optional, Tuple, Union

import torch
-from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors import BaseCompressor
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.utils.safetensors_load import get_weight_mappings
from safetensors import safe_open
17 changes: 17 additions & 0 deletions src/compressed_tensors/compressors/model_compressors/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa


from .model_compressor import *
src/compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py}
@@ -30,7 +30,7 @@
QUANTIZATION_METHOD_NAME,
SPARSITY_CONFIG_NAME,
)
-from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors import Compressor
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.quantization import (
DEFAULT_QUANTIZATION_METHOD,
@@ -247,11 +247,11 @@ def __init__(
self.sparsity_config = None

if sparsity_config is not None:
-        self.sparsity_compressor = BaseCompressor.load_from_registry(
+        self.sparsity_compressor = Compressor.load_from_registry(
sparsity_config.format, config=sparsity_config
)
if quantization_config is not None:
-        self.quantization_compressor = BaseCompressor.load_from_registry(
+        self.quantization_compressor = Compressor.load_from_registry(
quantization_config.format, config=quantization_config
)

@@ -262,7 +262,7 @@ def compress(
Compresses a dense state dict or model with sparsity and/or quantization
:param model: uncompressed model to compress
-    :param state_dict: optional uncompressed state_dict to insert into model
+    :param model_state: optional uncompressed state_dict to insert into model
:return: compressed state dict
"""
if state_dict is None:
@@ -393,4 +393,4 @@ def new_dtype_byte_size(dtype):
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
+    return bit_size // 8
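Tying this file back to the save-side lifecycle in the BaseCompressor docstring, a hedged usage sketch; only the __init__ and compress signatures are taken from this diff, and model plus quantization_config are placeholders assumed to exist:

from compressed_tensors.compressors import ModelCompressor

# Either config may be None, which simply skips that compressor.
compressor = ModelCompressor(
    sparsity_config=None,
    quantization_config=quantization_config,  # placeholder, loaded elsewhere
)

# Returns the compressed state dict, ready for serialization.
compressed_state_dict = compressor.compress(model)  # model is a placeholder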
18 changes: 18 additions & 0 deletions src/compressed_tensors/compressors/quantized_compressors/__init__.py
@@ -0,0 +1,18 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa

from .base import *
from .naive_quantized import *
from .pack_quantized import *
146 changes: 146 additions & 0 deletions src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -0,0 +1,146 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from safetensors import safe_open
from torch import Tensor
from tqdm import tqdm


_LOGGER: logging.Logger = logging.getLogger(__name__)

__all__ = ["BaseQuantizationCompressor"]


class BaseQuantizationCompressor(BaseCompressor):
"""
Base class representing a quant compression algorithm. Each child class should
implement compression_param_info, compress_weight and decompress_weight.
Compressors support compressing/decompressing a full module state dict or a single
quantized PyTorch leaf module.
    Model Load Lifecycle (run_compressed=False):
        - ModelCompressor.decompress()
            - apply_quantization_config()
            - BaseQuantizationCompressor.decompress()
                - BaseQuantizationCompressor.decompress_weight()
    Model Save Lifecycle:
        - ModelCompressor.compress()
            - BaseQuantizationCompressor.compress()
                - BaseQuantizationCompressor.compress_weight()
    Module Lifecycle (run_compressed=True):
        - apply_quantization_config()
        - compressed_module = CompressedLinear(module)
            - initialize_module_for_quantization()
            - BaseQuantizationCompressor.compression_param_info()
            - register_parameters()
        - compressed_module.forward()
            - compressed_module.decompress()
:param config: config specifying compression parameters
"""

def compress(
self,
model_state: Dict[str, Tensor],
names_to_scheme: Dict[str, QuantizationArgs],
**kwargs,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict
:param model_state: state dict of uncompressed model
:param names_to_scheme: quantization args for each quantized weight, needed for
quantize function to calculate bit depth
:return: compressed state dict
"""
compressed_dict = {}
weight_suffix = ".weight"
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)

for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
if name.endswith(weight_suffix):
prefix = name[: -(len(weight_suffix))]
scale = model_state.get(merge_names(prefix, "weight_scale"), None)
zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
if scale is not None:
# weight is quantized, compress it
quant_args = names_to_scheme[prefix]
compressed_data = self.compress_weight(
weight=value,
scale=scale,
zero_point=zp,
g_idx=g_idx,
quantization_args=quant_args,
device="cpu",
)
for key, value in compressed_data.items():
compressed_dict[merge_names(prefix, key)] = value
else:
compressed_dict[name] = value.to("cpu")
elif name.endswith("zero_point") and torch.all(value == 0):
continue
elif name.endswith("g_idx") and torch.any(value <= -1):
continue
else:
compressed_dict[name] = value.to("cpu")

return compressed_dict

def decompress(
self,
path_to_model_or_tensors: str,
names_to_scheme: Dict[str, QuantizationArgs],
device: str = "cpu",
) -> Generator[Tuple[str, Tensor], None, None]:
"""
Reads a compressed state dict located at path_to_model_or_tensors
and returns a generator for sequentially decompressing back to a
dense state dict
:param path_to_model_or_tensors: path to compressed safetensors model (directory
with one or more safetensors files) or compressed tensors file
:param names_to_scheme: quantization args for each quantized weight
:param device: optional device to load intermediate weights into
:return: generator of (name, tensor) pairs of the decompressed state dict
"""
weight_mappings = get_nested_weight_mappings(
path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
)
for weight_name in weight_mappings.keys():
weight_data = {}
for param_name, safe_path in weight_mappings[weight_name].items():
full_name = merge_names(weight_name, param_name)
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)

if "weight_scale" in weight_data:
quant_args = names_to_scheme[weight_name]
decompressed = self.decompress_weight(
compressed_data=weight_data, quantization_args=quant_args
)
yield merge_names(weight_name, "weight"), decompressed
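Because decompress() is a generator, callers can stream dense weights back one module at a time instead of materializing everything at once. A minimal consumption sketch; the checkpoint path and the empty names_to_scheme mapping are placeholders:

from compressed_tensors.compressors import BaseCompressor
from compressed_tensors.config import CompressionFormat

compressor = BaseCompressor.load_from_registry(
    CompressionFormat.naive_quantized.value
)

dense_state_dict = {}
# Each yielded name is the dense parameter name ("<prefix>.weight"),
# rebuilt from the per-parameter compressed tensors via merge_names.
for name, tensor in compressor.decompress(
    "path/to/compressed-model", names_to_scheme={}
):
    dense_state_dict[name] = tensor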
src/compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/naive_quantized.py}
@@ -16,7 +16,7 @@

import torch
from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.compressors.base_quantization_compressor import (
+from compressed_tensors.compressors.quantized_compressors.base import (
BaseQuantizationCompressor,
)
from compressed_tensors.config import CompressionFormat
@@ -27,14 +27,14 @@


__all__ = [
"QuantizationCompressor",
"NaiveQuantizationCompressor",
"IntQuantizationCompressor",
"FloatQuantizationCompressor",
]


@BaseCompressor.register(name=CompressionFormat.naive_quantized.value)
-class QuantizationCompressor(BaseQuantizationCompressor):
+class NaiveQuantizationCompressor(BaseQuantizationCompressor):
"""
Implements naive compression for quantized models. Weight of each
quantized layer is converted from its original float type to the closest Pytorch
@@ -123,7 +123,7 @@ def decompress_weight(


@BaseCompressor.register(name=CompressionFormat.int_quantized.value)
-class IntQuantizationCompressor(QuantizationCompressor):
+class IntQuantizationCompressor(NaiveQuantizationCompressor):
"""
Alias for integer quantized models
"""
@@ -132,7 +132,7 @@ class IntQuantizationCompressor(QuantizationCompressor):


@BaseCompressor.register(name=CompressionFormat.float_quantized.value)
-class FloatQuantizationCompressor(QuantizationCompressor):
+class FloatQuantizationCompressor(NaiveQuantizationCompressor):
"""
Alias for fp quantized models
"""
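Since IntQuantizationCompressor and FloatQuantizationCompressor only alias the naive implementation, all three registered format names should resolve to compatible compressors. A quick sketch of the expected registry behavior, based only on the register decorators above:

from compressed_tensors.compressors import (
    BaseCompressor,
    NaiveQuantizationCompressor,
)
from compressed_tensors.config import CompressionFormat

int_compressor = BaseCompressor.load_from_registry(
    CompressionFormat.int_quantized.value
)
# Expected to hold, since IntQuantizationCompressor subclasses the naive one.
assert isinstance(int_compressor, NaiveQuantizationCompressor)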
src/compressed_tensors/compressors/{pack_quantized.py → quantized_compressors/pack_quantized.py}
@@ -17,7 +17,7 @@
import numpy as np
import torch
from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.compressors.base_quantization_compressor import (
+from compressed_tensors.compressors.quantized_compressors.base import (
BaseQuantizationCompressor,
)
from compressed_tensors.config import CompressionFormat
18 changes: 18 additions & 0 deletions src/compressed_tensors/compressors/sparse_compressors/__init__.py
@@ -0,0 +1,18 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa

from .base import *
from .dense import *
from .sparse_bitmask import *