[compressor] Add packed int8 support #91

Merged (10 commits) on Jun 24, 2024

src/compressed_tensors/compressors/base.py (2 changes: 1 addition & 1 deletion)
@@ -45,7 +45,7 @@ def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
raise NotImplementedError()

def decompress(
self, path_to_model_or_tensors: str, device: str = "cpu"
self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
) -> Generator[Tuple[str, Tensor], None, None]:
"""
Reads a compressed state dict located at path_to_model_or_tensors
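The only functional change here is that the base decompress signature now accepts **kwargs, so subclasses can require extra arguments (such as the names_to_scheme mapping used by the packed compressor below) while callers keep a single calling convention. A minimal usage sketch, assuming a compressor instance and mapping already exist (the path is a placeholder):

    # hedged sketch: extra keyword arguments flow through the shared decompress interface
    for name, tensor in compressor.decompress(
        "path/to/compressed_model", names_to_scheme=names_to_scheme
    ):
        # each yielded pair is a parameter name and its decompressed tensor
        print(name, tuple(tensor.shape))
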
src/compressed_tensors/compressors/marlin_24.py (8 changes: 4 additions & 4 deletions)
@@ -107,19 +107,19 @@ def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
def compress(
self,
model_state: Dict[str, Tensor],
model_quant_args: Dict[str, QuantizationArgs],
names_to_scheme: Dict[str, QuantizationArgs],
**kwargs,
) -> Dict[str, Tensor]:
"""
Compresses a quantized state_dict with 2:4 sparsity structure for inference
with the Marlin24 kernel

:param model_state: state dict of uncompressed model
:param model_quant_args: quantization args for each quantized weight, needed for
:param names_to_scheme: quantization args for each quantized weight, needed for
quantize function to calculate bit depth
:return: compressed state dict
"""
self.validate_quant_compatability(model_quant_args)
self.validate_quant_compatability(names_to_scheme)

compressed_dict = {}
weight_suffix = ".weight"
@@ -139,7 +139,7 @@ def compress(
value = value.to(torch.float16)

# quantize weight, keeping it as a float16 for now
quant_args = model_quant_args[prefix]
quant_args = names_to_scheme[prefix]
value = quantize(
x=value, scale=scale, zero_point=zp, args=quant_args
)
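This hunk is purely the keyword rename from model_quant_args to names_to_scheme; the value is still a dict mapping each quantized module prefix to its QuantizationArgs, mirroring the updated tests at the bottom of this diff. A hedged sketch of the calling convention (the module name and config group are placeholders):

    # hedged sketch: build the name -> QuantizationArgs mapping, then compress
    names_to_scheme = {
        "model.layers.0.self_attn.q_proj": quant_config.config_groups["group_1"].weights,
    }
    compressed_state_dict = compressor.compress(
        model.state_dict(), names_to_scheme=names_to_scheme
    )
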
src/compressed_tensors/compressors/model_compressor.py (8 changes: 5 additions & 3 deletions)
@@ -231,7 +231,7 @@ def compress(
quantized_modules_to_args = map_modules_to_quant_args(model)
if self.quantization_compressor is not None:
compressed_state_dict = self.quantization_compressor.compress(
state_dict, model_quant_args=quantized_modules_to_args
state_dict, names_to_scheme=quantized_modules_to_args
)

if self.sparsity_compressor is not None:
@@ -260,9 +260,11 @@ def decompress(self, model_path: str, model: Module):
setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)

if self.quantization_compressor is not None:
apply_quantization_config(model, self.quantization_config)
names_to_scheme = apply_quantization_config(model, self.quantization_config)
load_pretrained_quantization(model, model_path)
dense_gen = self.quantization_compressor.decompress(model_path)
dense_gen = self.quantization_compressor.decompress(
model_path, names_to_scheme=names_to_scheme
)
self._replace_weights(dense_gen, model)

def update_status(module):
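ModelCompressor.decompress now captures the mapping returned by apply_quantization_config and threads it into the quantization compressor's decompress call, which is what lets the packed compressor below look up each weight's bit width. Paraphrased as a sketch (not the exact source; _replace_weights is the existing helper used above):

    # hedged sketch of the decompress flow added in this hunk
    names_to_scheme = apply_quantization_config(model, self.quantization_config)
    load_pretrained_quantization(model, model_path)
    dense_gen = self.quantization_compressor.decompress(
        model_path, names_to_scheme=names_to_scheme
    )
    self._replace_weights(dense_gen, model)  # swap packed weights for dense ones
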
src/compressed_tensors/compressors/naive_quantized.py (6 changes: 3 additions & 3 deletions)
@@ -49,14 +49,14 @@ class QuantizationCompressor(Compressor):
def compress(
self,
model_state: Dict[str, Tensor],
model_quant_args: Dict[str, QuantizationArgs],
names_to_scheme: Dict[str, QuantizationArgs],
**kwargs,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict

:param model_state: state dict of uncompressed model
:param model_quant_args: quantization args for each quantized weight, needed for
:param names_to_scheme: quantization args for each quantized weight, needed for
quantize function to calculate bit depth
:return: compressed state dict
"""
@@ -73,7 +73,7 @@ def compress(
zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
if scale is not None and zp is not None:
# weight is quantized, compress it
quant_args = model_quant_args[prefix]
quant_args = names_to_scheme[prefix]
if can_quantize(value, quant_args):
# only quantize if not already quantized
value = quantize(
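The rename applies here as well: inside compress the module prefix is used to look up the right QuantizationArgs, and the can_quantize guard skips weights that are already integer-valued. A rough sketch of that per-weight branch, using only names visible in the hunk above:

    # hedged sketch of the per-weight branch in compress()
    scale = model_state.get(merge_names(prefix, "weight_scale"), None)
    zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
    if scale is not None and zp is not None:      # weight has quantization params
        quant_args = names_to_scheme[prefix]
        if can_quantize(value, quant_args):       # only quantize if still dense
            value = quantize(x=value, scale=scale, zero_point=zp, args=quant_args)
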
src/compressed_tensors/compressors/pack_quantized.py (104 changes: 84 additions & 20 deletions)
@@ -29,7 +29,13 @@
from tqdm import tqdm


__all__ = ["PackedQuantizationCompressor", "pack_4bit_ints", "unpack_4bit_ints"]
__all__ = [
"PackedQuantizationCompressor",
"pack_4bit_ints",
"pack_8bit_ints",
"unpack_4bit_ints",
"unpack_8bit_ints",
]

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -50,14 +56,14 @@ class PackedQuantizationCompressor(Compressor):
def compress(
self,
model_state: Dict[str, Tensor],
model_quant_args: Dict[str, QuantizationArgs],
names_to_scheme: Dict[str, QuantizationArgs],
**kwargs,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict

:param model_state: state dict of uncompressed model
:param model_quant_args: quantization args for each quantized weight, needed for
:param names_to_scheme: quantization args for each quantized weight, needed for
quantize function to calculate bit depth
:return: compressed state dict
"""
@@ -75,7 +81,7 @@
shape = torch.tensor(value.shape)
if scale is not None and zp is not None:
# weight is quantized, compress it
quant_args = model_quant_args[prefix]
quant_args = names_to_scheme[prefix]
if can_quantize(value, quant_args):
# convert weight to an int if not already compressed
value = quantize(
@@ -85,7 +91,11 @@
args=quant_args,
dtype=torch.int8,
)
value = pack_4bit_ints(value.cpu())

if quant_args.num_bits == 8:
value = pack_8bit_ints(value.cpu())
else:
value = pack_4bit_ints(value.cpu())
compressed_dict[merge_names(prefix, "weight_shape")] = shape
compressed_dict[merge_names(prefix, "weight_packed")] = value
continue
@@ -101,7 +111,10 @@
return compressed_dict

def decompress(
self, path_to_model_or_tensors: str, device: str = "cpu"
self,
path_to_model_or_tensors: str,
names_to_scheme: Dict[str, QuantizationArgs],
device: str = "cpu",
) -> Generator[Tuple[str, Tensor], None, None]:
"""
Reads a compressed state dict located at path_to_model_or_tensors
@@ -119,6 +132,7 @@ def decompress(
for weight_name in weight_mappings.keys():
weight_data = {}
for param_name, safe_path in weight_mappings[weight_name].items():
weight_data["num_bits"] = names_to_scheme.get(weight_name).num_bits
full_name = merge_names(weight_name, param_name)
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)
@@ -127,8 +141,12 @@ def decompress(
zero_point = weight_data.get("weight_zero_point", None)
scale = weight_data["weight_scale"]
weight = weight_data["weight_packed"]
num_bits = weight_data["num_bits"]
original_shape = torch.Size(weight_data["weight_shape"])
unpacked = unpack_4bit_ints(weight, original_shape)
if num_bits == 4:
unpacked = unpack_4bit_ints(weight, original_shape)
else:
unpacked = unpack_8bit_ints(weight, original_shape)
decompressed = dequantize(
x_q=unpacked,
scale=scale,
@@ -137,6 +155,19 @@ def decompress(
yield merge_names(weight_name, "weight"), decompressed


def pack_8bit_ints(value: torch.Tensor) -> torch.Tensor:
"""
Packs a tensor of int8 into int32s with padding

:param value: tensor to pack
:returns: packed int32 tensor
"""
# need to convert to unsigned 8bit to use numpy's pack/unpack
value_uint = (value - 128).to(torch.uint8)
bits = np.unpackbits(value_uint, axis=-1, bitorder="little")
return _pack_bits(bits_to_pack=bits)


def pack_4bit_ints(value: torch.Tensor) -> torch.Tensor:
"""
Packs a tensor of int4 weights stored in int8 into int32s with padding
@@ -152,22 +183,31 @@ def pack_4bit_ints(value: torch.Tensor) -> torch.Tensor:
bits = np.unpackbits(temp.numpy(), axis=-1, bitorder="little")
ranges = np.array([range(x, x + 4) for x in range(0, bits.shape[1], 8)]).flatten()
only_4_bits = bits[:, ranges] # top 4 bits are 0 because we're really uint4
return _pack_bits(bits_to_pack=only_4_bits)

# pad each row to fill a full 32bit int
pack_depth = 32
padding = (
math.ceil(only_4_bits.shape[1] / pack_depth) * pack_depth - only_4_bits.shape[1]
)
padded_bits = np.pad(
only_4_bits, pad_width=[(0, 0), (0, padding)], constant_values=0
)

# after packbits each uint8 is two packed uint4s
# then we keep the bit pattern the same but convert to int32
compressed = np.packbits(padded_bits, axis=-1, bitorder="little")
compressed = np.ascontiguousarray(compressed).view(np.int32)
def unpack_8bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
"""
Unpacks a tensor of packed int8 weights stored as int32

return torch.from_numpy(compressed)
:param value: tensor to unpack
:param shape: shape to unpack into, used to remove padding
:returns: unpacked int8 tensor
"""
if value.dtype is not torch.int32:
raise ValueError(
f"Expected {torch.int32} but got {value.dtype}, Aborting unpack."
)

# unpack bits and undo padding to nearest int32 bits
individual_depth = 8
as_uint8 = value.numpy().view(np.uint8)
bits = np.unpackbits(as_uint8, axis=-1, bitorder="little")
original_row_size = int(shape[1] * individual_depth)
bits = bits[:, :original_row_size]
bits = np.packbits(bits, axis=-1, bitorder="little")
final = (bits - 128).astype(np.int8)
return torch.from_numpy(final)


def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
@@ -206,3 +246,27 @@ def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
final = repacked.astype(np.int8) - 8

return torch.from_numpy(final)


def _pack_bits(bits_to_pack: torch.Tensor) -> torch.Tensor:
"""
Pack a tensor of bits to int32.

:param bits_to_pack: tensor of bits to pack
"""
# pad each row to fill a full 32bit int
pack_depth = 32
padding = (
math.ceil(bits_to_pack.shape[1] / pack_depth) * pack_depth
- bits_to_pack.shape[1]
)
padded_bits = np.pad(
bits_to_pack, pad_width=[(0, 0), (0, padding)], constant_values=0
)

# after packbits each uint8 is two packed uint4s
# then we keep the bit pattern the same but convert to int32
compressed = np.packbits(padded_bits, axis=-1, bitorder="little")
compressed = np.ascontiguousarray(compressed).view(np.int32)

return torch.from_numpy(compressed)
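Taken together, pack_8bit_ints shifts int8 values into the uint8 range, unpacks each row to bits, and _pack_bits pads every row to a multiple of 32 bits before viewing the buffer as int32; unpack_8bit_ints reverses the process, using the stored weight_shape to drop the padding, while decompress chooses the 4-bit or 8-bit unpacker from names_to_scheme[...].num_bits. A hedged round-trip sketch on a toy tensor (the import path is assumed from this module's location; real weights go through PackedQuantizationCompressor.compress/decompress):

    import torch

    from compressed_tensors.compressors.pack_quantized import (
        pack_8bit_ints,
        unpack_8bit_ints,
    )

    # toy int8 "weight": 4 rows of 10 values each
    weight = torch.randint(-128, 128, (4, 10), dtype=torch.int8)

    packed = pack_8bit_ints(weight)
    # 10 values * 8 bits = 80 bits per row, padded to 96 bits = 3 int32 words
    assert packed.dtype == torch.int32 and packed.shape == (4, 3)

    # the original shape (normally saved as "weight_shape") strips the padding
    unpacked = unpack_8bit_ints(packed, weight.shape)
    assert torch.equal(unpacked, weight)

The +/-128 offset exists only because numpy's packbits/unpackbits operate on uint8; the unpack path applies the inverse offset, so the signed values should round-trip exactly.
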
src/compressed_tensors/quantization/lifecycle/apply.py (6 changes: 5 additions & 1 deletion)
@@ -96,7 +96,7 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
)


def apply_quantization_config(model: Module, config: QuantizationConfig):
def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict:
"""
Initializes the model for quantization in-place based on the given config

@@ -106,6 +106,7 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
# build mapping of targets to schemes for easier matching
# use ordered dict to preserve target ordering in config
target_to_scheme = OrderedDict()
names_to_scheme = OrderedDict()
for scheme in config.config_groups.values():
for target in scheme.targets:
target_to_scheme[target] = scheme
@@ -123,6 +124,7 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
if target is not None:
# target matched - add layer and scheme to target list
submodule.quantization_scheme = target_to_scheme[target]
names_to_scheme[name] = submodule.quantization_scheme.weights

if config.ignore is not None and ignored_submodules is not None:
if set(config.ignore) - set(ignored_submodules):
@@ -132,7 +134,9 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
f"{set(config.ignore) - set(ignored_submodules)}"
)
# apply current quantization status across all targeted layers

apply_quantization_status(model, config.quantization_status)
return names_to_scheme


def apply_quantization_status(model: Module, status: QuantizationStatus):
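apply_quantization_config now returns the names_to_scheme mapping it builds while walking the model: keys are the names of targeted submodules, values are the matched scheme's weight QuantizationArgs, which is exactly the structure the compressors above consume. A hedged sketch of inspecting the return value (module names depend entirely on the model and config, so these are placeholders):

    # hedged sketch: every targeted submodule maps to its weight QuantizationArgs
    names_to_scheme = apply_quantization_config(model, quantization_config)
    for name, weight_args in names_to_scheme.items():
        # e.g. "model.layers.0.mlp.down_proj" -> num_bits 8 for the new packed path
        print(name, weight_args.num_bits)
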
src/compressed_tensors/utils/helpers.py (1 change: 0 additions & 1 deletion)
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional

from transformers import AutoConfig
tests/test_compressors/test_fp8_quant.py (4 changes: 2 additions & 2 deletions)
@@ -79,7 +79,7 @@ def test_quant_format(strategy, group_size, sc, zp):
compressor = FloatQuantizationCompressor(config=quant_config)
quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
compressed_state_dict = compressor.compress(
dense_state_dict, model_quant_args=quantized_modules_to_args
dense_state_dict, names_to_scheme=quantized_modules_to_args
)

# state_dict params should be the same, minus the zero_point if symmetric
@@ -118,7 +118,7 @@ def test_reload_match(strategy, group_size, tmp_path):
"dummy": quant_config.config_groups["group_1"].weights,
}
compressed_state_dict = compressor.compress(
model.state_dict(), model_quant_args=quantized_modules_to_args
model.state_dict(), names_to_scheme=quantized_modules_to_args
)
save_file(compressed_state_dict, tmp_path / "model.safetensors")
reconstructed_dense_gen = compressor.decompress(tmp_path)
tests/test_compressors/test_int_quant.py (4 changes: 2 additions & 2 deletions)
@@ -74,7 +74,7 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
compressor = IntQuantizationCompressor(config=quant_config)
quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
compressed_state_dict = compressor.compress(
dense_state_dict, model_quant_args=quantized_modules_to_args
dense_state_dict, names_to_scheme=quantized_modules_to_args
)

# state_dict params should be the same, minus the zero_point if symmetric
@@ -125,7 +125,7 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path):
"dummy2": quant_config.config_groups["group_1"].weights,
}
compressed_state_dict = compressor.compress(
dense_state_dict, model_quant_args=quantized_modules_to_args
dense_state_dict, names_to_scheme=quantized_modules_to_args
)
save_file(compressed_state_dict, tmp_path / "model.safetensors")
reconstructed_dense_gen = compressor.decompress(tmp_path)