diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py
index c4c2debca..4c8c97728 100644
--- a/src/compressed_tensors/compressors/base.py
+++ b/src/compressed_tensors/compressors/base.py
@@ -45,7 +45,7 @@ def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor
         raise NotImplementedError()
 
     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
diff --git a/src/compressed_tensors/compressors/marlin_24.py b/src/compressed_tensors/compressors/marlin_24.py
index 1c172e96b..1abf0a754 100644
--- a/src/compressed_tensors/compressors/marlin_24.py
+++ b/src/compressed_tensors/compressors/marlin_24.py
@@ -107,7 +107,7 @@ def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        model_quant_args: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -115,11 +115,11 @@ def compress(
         with the Marlin24 kernel
 
         :param model_state: state dict of uncompressed model
-        :param model_quant_args: quantization args for each quantized weight, needed for
+        :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
         :return: compressed state dict
         """
-        self.validate_quant_compatability(model_quant_args)
+        self.validate_quant_compatability(names_to_scheme)
 
         compressed_dict = {}
         weight_suffix = ".weight"
@@ -139,7 +139,7 @@ def compress(
                 value = value.to(torch.float16)
 
                 # quantize weight, keeping it as a float16 for now
-                quant_args = model_quant_args[prefix]
+                quant_args = names_to_scheme[prefix]
                 value = quantize(
                     x=value, scale=scale, zero_point=zp, args=quant_args
                 )
diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py
index e57dbff6d..ac176291c 100644
--- a/src/compressed_tensors/compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressor.py
@@ -231,7 +231,7 @@ def compress(
         quantized_modules_to_args = map_modules_to_quant_args(model)
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
-                state_dict, model_quant_args=quantized_modules_to_args
+                state_dict, names_to_scheme=quantized_modules_to_args
             )
 
         if self.sparsity_compressor is not None:
@@ -260,9 +260,11 @@ def decompress(self, model_path: str, model: Module):
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
 
         if self.quantization_compressor is not None:
-            apply_quantization_config(model, self.quantization_config)
+            names_to_scheme = apply_quantization_config(model, self.quantization_config)
             load_pretrained_quantization(model, model_path)
-            dense_gen = self.quantization_compressor.decompress(model_path)
+            dense_gen = self.quantization_compressor.decompress(
+                model_path, names_to_scheme=names_to_scheme
+            )
             self._replace_weights(dense_gen, model)
 
             def update_status(module):
diff --git a/src/compressed_tensors/compressors/naive_quantized.py b/src/compressed_tensors/compressors/naive_quantized.py
index e3a9a42fd..f54d78c49 100644
--- a/src/compressed_tensors/compressors/naive_quantized.py
+++ b/src/compressed_tensors/compressors/naive_quantized.py
@@ -49,14 +49,14 @@ class QuantizationCompressor(Compressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        model_quant_args: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict
 
         :param model_state: state dict of uncompressed model
-        :param model_quant_args: quantization args for each quantized weight, needed for
+        :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
         :return: compressed state dict
         """
@@ -73,7 +73,7 @@ def compress(
             zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
             if scale is not None and zp is not None:
                 # weight is quantized, compress it
-                quant_args = model_quant_args[prefix]
+                quant_args = names_to_scheme[prefix]
                 if can_quantize(value, quant_args):
                     # only quantize if not already quantized
                     value = quantize(
diff --git a/src/compressed_tensors/compressors/pack_quantized.py b/src/compressed_tensors/compressors/pack_quantized.py
index 74b78132d..b8585fd54 100644
--- a/src/compressed_tensors/compressors/pack_quantized.py
+++ b/src/compressed_tensors/compressors/pack_quantized.py
@@ -29,7 +29,13 @@
 from tqdm import tqdm
 
 
-__all__ = ["PackedQuantizationCompressor", "pack_4bit_ints", "unpack_4bit_ints"]
+__all__ = [
+    "PackedQuantizationCompressor",
+    "pack_4bit_ints",
+    "pack_8bit_ints",
+    "unpack_4bit_ints",
+    "unpack_8bit_ints",
+]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -50,14 +56,14 @@ class PackedQuantizationCompressor(Compressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        model_quant_args: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict
 
         :param model_state: state dict of uncompressed model
-        :param model_quant_args: quantization args for each quantized weight, needed for
+        :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
         :return: compressed state dict
         """
@@ -75,7 +81,7 @@ def compress(
                 shape = torch.tensor(value.shape)
                 if scale is not None and zp is not None:
                     # weight is quantized, compress it
-                    quant_args = model_quant_args[prefix]
+                    quant_args = names_to_scheme[prefix]
                     if can_quantize(value, quant_args):
                         # convert weight to an int if not already compressed
                         value = quantize(
@@ -85,7 +91,11 @@ def compress(
                             args=quant_args,
                             dtype=torch.int8,
                         )
-                    value = pack_4bit_ints(value.cpu())
+
+                    if quant_args.num_bits == 8:
+                        value = pack_8bit_ints(value.cpu())
+                    else:
+                        value = pack_4bit_ints(value.cpu())
                     compressed_dict[merge_names(prefix, "weight_shape")] = shape
                     compressed_dict[merge_names(prefix, "weight_packed")] = value
                     continue
@@ -101,7 +111,10 @@ def compress(
         return compressed_dict
 
     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self,
+        path_to_model_or_tensors: str,
+        names_to_scheme: Dict[str, QuantizationArgs],
+        device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
@@ -119,6 +132,7 @@ def decompress(
         for weight_name in weight_mappings.keys():
             weight_data = {}
             for param_name, safe_path in weight_mappings[weight_name].items():
+                weight_data["num_bits"] = names_to_scheme.get(weight_name).num_bits
                 full_name = merge_names(weight_name, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
@@ -127,8 +141,12 @@ def decompress(
                 zero_point = weight_data.get("weight_zero_point", None)
                 scale = weight_data["weight_scale"]
                 weight = weight_data["weight_packed"]
+                num_bits = weight_data["num_bits"]
                 original_shape = torch.Size(weight_data["weight_shape"])
-                unpacked = unpack_4bit_ints(weight, original_shape)
+                if num_bits == 4:
+                    unpacked = unpack_4bit_ints(weight, original_shape)
+                else:
+                    unpacked = unpack_8bit_ints(weight, original_shape)
                 decompressed = dequantize(
                     x_q=unpacked,
                     scale=scale,
@@ -137,6 +155,19 @@ def decompress(
                 yield merge_names(weight_name, "weight"), decompressed
 
 
+def pack_8bit_ints(value: torch.Tensor) -> torch.Tensor:
+    """
+    Packs a tensor of int8 into int32s with padding
+
+    :param value: tensor to pack
+    :returns: packed int32 tensor
+    """
+    # need to convert to unsigned 8bit to use numpy's pack/unpack
+    value_uint = (value - 128).to(torch.uint8)
+    bits = np.unpackbits(value_uint, axis=-1, bitorder="little")
+    return _pack_bits(bits_to_pack=bits)
+
+
 def pack_4bit_ints(value: torch.Tensor) -> torch.Tensor:
     """
     Packs a tensor of int4 weights stored in int8 into int32s with padding
@@ -152,22 +183,31 @@ def pack_4bit_ints(value: torch.Tensor) -> torch.Tensor:
     bits = np.unpackbits(temp.numpy(), axis=-1, bitorder="little")
     ranges = np.array([range(x, x + 4) for x in range(0, bits.shape[1], 8)]).flatten()
     only_4_bits = bits[:, ranges]  # top 4 bits are 0 because we're really uint4
+    return _pack_bits(bits_to_pack=only_4_bits)
 
-    # pad each row to fill a full 32bit int
-    pack_depth = 32
-    padding = (
-        math.ceil(only_4_bits.shape[1] / pack_depth) * pack_depth - only_4_bits.shape[1]
-    )
-    padded_bits = np.pad(
-        only_4_bits, pad_width=[(0, 0), (0, padding)], constant_values=0
-    )
 
-    # after packbits each uint8 is two packed uint4s
-    # then we keep the bit pattern the same but convert to int32
-    compressed = np.packbits(padded_bits, axis=-1, bitorder="little")
-    compressed = np.ascontiguousarray(compressed).view(np.int32)
+def unpack_8bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
+    """
+    Unpacks a tensor of packed int8 weights stored in int32
 
-    return torch.from_numpy(compressed)
+    :param value: tensor to unpack
+    :param shape: shape to unpack into, used to remove padding
+    :returns: unpacked int8 tensor
+    """
+    if value.dtype is not torch.int32:
+        raise ValueError(
+            f"Expected {torch.int32} but got {value.dtype}, Aborting unpack."
+        )
+
+    # unpack bits and undo padding to nearest int32 bits
+    individual_depth = 8
+    as_uint8 = value.numpy().view(np.uint8)
+    bits = np.unpackbits(as_uint8, axis=-1, bitorder="little")
+    original_row_size = int(shape[1] * individual_depth)
+    bits = bits[:, :original_row_size]
+    bits = np.packbits(bits, axis=-1, bitorder="little")
+    final = (bits - 128).astype(np.int8)
+    return torch.from_numpy(final)
 
 
 def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
@@ -206,3 +246,27 @@
     final = repacked.astype(np.int8) - 8
 
     return torch.from_numpy(final)
+
+
+def _pack_bits(bits_to_pack: torch.Tensor) -> torch.Tensor:
+    """
+    Pack a tensor of bits to int32.
+
+    :param bits_to_pack: tensor of bits to pack
+    """
+    # pad each row to fill a full 32bit int
+    pack_depth = 32
+    padding = (
+        math.ceil(bits_to_pack.shape[1] / pack_depth) * pack_depth
+        - bits_to_pack.shape[1]
+    )
+    padded_bits = np.pad(
+        bits_to_pack, pad_width=[(0, 0), (0, padding)], constant_values=0
+    )
+
+    # after packbits each uint8 is two packed uint4s
+    # then we keep the bit pattern the same but convert to int32
+    compressed = np.packbits(padded_bits, axis=-1, bitorder="little")
+    compressed = np.ascontiguousarray(compressed).view(np.int32)
+
+    return torch.from_numpy(compressed)
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index f31bdc78b..fe9fc8d43 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -96,7 +96,7 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
         )
 
 
-def apply_quantization_config(model: Module, config: QuantizationConfig):
+def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict:
     """
     Initializes the model for quantization in-place based on the given config
 
@@ -106,6 +106,7 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
+    names_to_scheme = OrderedDict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
@@ -123,6 +124,7 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
         if target is not None:
             # target matched - add layer and scheme to target list
             submodule.quantization_scheme = target_to_scheme[target]
+            names_to_scheme[name] = submodule.quantization_scheme.weights
 
     if config.ignore is not None and ignored_submodules is not None:
         if set(config.ignore) - set(ignored_submodules):
@@ -132,7 +134,9 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
             f"{set(config.ignore) - set(ignored_submodules)}"
         )
     # apply current quantization status across all targeted layers
+
     apply_quantization_status(model, config.quantization_status)
+    return names_to_scheme
 
 
 def apply_quantization_status(model: Module, status: QuantizationStatus):
diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py
index 3e86a9fc3..077cac256 100644
--- a/src/compressed_tensors/utils/helpers.py
+++ b/src/compressed_tensors/utils/helpers.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from typing import Optional
 
 from transformers import AutoConfig
diff --git a/tests/test_compressors/test_fp8_quant.py b/tests/test_compressors/test_fp8_quant.py
index eb21b575c..62769507c 100644
--- a/tests/test_compressors/test_fp8_quant.py
+++ b/tests/test_compressors/test_fp8_quant.py
@@ -79,7 +79,7 @@ def test_quant_format(strategy, group_size, sc, zp):
     compressor = FloatQuantizationCompressor(config=quant_config)
     quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
     compressed_state_dict = compressor.compress(
-        dense_state_dict, model_quant_args=quantized_modules_to_args
+        dense_state_dict, names_to_scheme=quantized_modules_to_args
     )
 
     # state_dict params should be the same, minus the zero_point if symmetric
@@ -118,7 +118,7 @@ def test_reload_match(strategy, group_size, tmp_path):
         "dummy": quant_config.config_groups["group_1"].weights,
     }
     compressed_state_dict = compressor.compress(
-        model.state_dict(), model_quant_args=quantized_modules_to_args
+        model.state_dict(), names_to_scheme=quantized_modules_to_args
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
     reconstructed_dense_gen = compressor.decompress(tmp_path)
diff --git a/tests/test_compressors/test_int_quant.py b/tests/test_compressors/test_int_quant.py
index 4adce2a79..f7a2a33cb 100644
--- a/tests/test_compressors/test_int_quant.py
+++ b/tests/test_compressors/test_int_quant.py
@@ -74,7 +74,7 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
     compressor = IntQuantizationCompressor(config=quant_config)
     quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
     compressed_state_dict = compressor.compress(
-        dense_state_dict, model_quant_args=quantized_modules_to_args
+        dense_state_dict, names_to_scheme=quantized_modules_to_args
    )
 
     # state_dict params should be the same, minus the zero_point if symmetric
@@ -125,7 +125,7 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path):
         "dummy2": quant_config.config_groups["group_1"].weights,
     }
     compressed_state_dict = compressor.compress(
-        dense_state_dict, model_quant_args=quantized_modules_to_args
+        dense_state_dict, names_to_scheme=quantized_modules_to_args
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
     reconstructed_dense_gen = compressor.decompress(tmp_path)
diff --git a/tests/test_compressors/test_pack_quant.py b/tests/test_compressors/test_pack_quant.py
index 46bf091e5..e5bf82194 100644
--- a/tests/test_compressors/test_pack_quant.py
+++ b/tests/test_compressors/test_pack_quant.py
@@ -32,10 +32,10 @@
 from safetensors.torch import save_file
 
 
-def get_dummy_quant_config():
+def get_dummy_quant_config(num_bits=4):
     config_groups = {
         "group_1": QuantizationScheme(
-            targets=["Linear"], weights=QuantizationArgs(num_bits=4)
+            targets=["Linear"], weights=QuantizationArgs(num_bits=num_bits)
         ),
     }
     ignore = ["lm_head"]
@@ -67,7 +67,7 @@ def test_quant_format(shape):
     compressor = PackedQuantizationCompressor(config=quant_config)
     quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
     compressed_state_dict = compressor.compress(
-        dense_state_dict, model_quant_args=quantized_modules_to_args
+        dense_state_dict, names_to_scheme=quantized_modules_to_args
     )
 
     # compressed state_dict adds one entry for shape
@@ -106,7 +106,8 @@ def test_repack(value):
     assert torch.equal(value, unpacked)
 
 
-def test_reload_match(tmp_path):
+@pytest.mark.parametrize("num_bits", [4, 8])
+def test_reload_match(tmp_path, num_bits):
     dense_state_dict = {
         "dummy.weight": torch.rand((511, 350)),
         "dummy.weight_scale": torch.tensor(0.01, dtype=torch.float32),
@@ -115,7 +116,12 @@
         "dummy2.weight_scale": torch.tensor(0.02, dtype=torch.float32),
         "dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int8),
     }
-    quant_config = get_dummy_quant_config()
+
+    names_to_scheme = {
+        "dummy": QuantizationArgs(num_bits=num_bits),
+        "dummy2": QuantizationArgs(num_bits=num_bits),
+    }
+    quant_config = get_dummy_quant_config(num_bits)
     compressor = PackedQuantizationCompressor(config=quant_config)
 
     quantized_modules_to_args = {
@@ -123,10 +129,12 @@
         "dummy2": quant_config.config_groups["group_1"].weights,
     }
     compressed_state_dict = compressor.compress(
-        dense_state_dict, model_quant_args=quantized_modules_to_args
+        dense_state_dict, names_to_scheme=quantized_modules_to_args
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
-    reconstructed_dense_gen = compressor.decompress(tmp_path)
+    reconstructed_dense_gen = compressor.decompress(
+        tmp_path, names_to_scheme=names_to_scheme
+    )
     reconstructed_dense = {}
     for name, value in reconstructed_dense_gen:
         reconstructed_dense[name] = value
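
Reviewer note, not part of the patch: the 8-bit path mirrors the existing 4-bit one, with pack_8bit_ints/unpack_8bit_ints selected on QuantizationArgs.num_bits, and PackedQuantizationCompressor.decompress now requires names_to_scheme so it knows which unpack to apply (exercised by the parametrized test_reload_match above). Below is a minimal round-trip sketch of just the two new helpers, assuming they are importable from compressed_tensors.compressors.pack_quantized as exported in __all__; the tensor name, shape, and values are illustrative only.

    import torch

    from compressed_tensors.compressors.pack_quantized import (
        pack_8bit_ints,
        unpack_8bit_ints,
    )

    # fake int8 "quantized" weight whose row length is not a multiple of 4 bytes,
    # so each packed int32 row carries padding that unpack must strip
    value = torch.randint(-128, 127, (64, 513), dtype=torch.int8)

    packed = pack_8bit_ints(value)  # int32 tensor, rows padded to a 32-bit boundary
    assert packed.dtype == torch.int32

    # the original shape is needed to trim the per-row padding
    unpacked = unpack_8bit_ints(packed, value.shape)
    assert unpacked.dtype == torch.int8
    assert torch.equal(unpacked, value)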