
Commit f3b0948

[compressor] Add packed int8 support (#91)
* add function to pack bits
* fix arg
* make 4bits the default
* update
* add support for int8 decompress; update function to take in name to scheme mapping
* update to test 8 bits; update kwargs
* fix print; update name
* update tests
* update arg
* update all other classes
1 parent 6319bc1 commit f3b0948

10 files changed: 121 additions, 44 deletions


src/compressed_tensors/compressors/base.py

Lines changed: 1 addition & 1 deletion

@@ -45,7 +45,7 @@ def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
         raise NotImplementedError()

     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors

src/compressed_tensors/compressors/marlin_24.py

Lines changed: 4 additions & 4 deletions

@@ -107,19 +107,19 @@ def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        model_quant_args: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
         Compresses a quantized state_dict with 2:4 sparsity structure for inference
         with the Marlin24 kernel

         :param model_state: state dict of uncompressed model
-        :param model_quant_args: quantization args for each quantized weight, needed for
+        :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
         :return: compressed state dict
         """
-        self.validate_quant_compatability(model_quant_args)
+        self.validate_quant_compatability(names_to_scheme)

         compressed_dict = {}
         weight_suffix = ".weight"
@@ -139,7 +139,7 @@ def compress(
             value = value.to(torch.float16)

             # quantize weight, keeping it as a float16 for now
-            quant_args = model_quant_args[prefix]
+            quant_args = names_to_scheme[prefix]
             value = quantize(
                 x=value, scale=scale, zero_point=zp, args=quant_args
             )

src/compressed_tensors/compressors/model_compressor.py

Lines changed: 5 additions & 3 deletions

@@ -231,7 +231,7 @@ def compress(
         quantized_modules_to_args = map_modules_to_quant_args(model)
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
-                state_dict, model_quant_args=quantized_modules_to_args
+                state_dict, names_to_scheme=quantized_modules_to_args
             )

         if self.sparsity_compressor is not None:
@@ -260,9 +260,11 @@ def decompress(self, model_path: str, model: Module):
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)

         if self.quantization_compressor is not None:
-            apply_quantization_config(model, self.quantization_config)
+            names_to_scheme = apply_quantization_config(model, self.quantization_config)
             load_pretrained_quantization(model, model_path)
-            dense_gen = self.quantization_compressor.decompress(model_path)
+            dense_gen = self.quantization_compressor.decompress(
+                model_path, names_to_scheme=names_to_scheme
+            )
             self._replace_weights(dense_gen, model)

         def update_status(module):

src/compressed_tensors/compressors/naive_quantized.py

Lines changed: 3 additions & 3 deletions

@@ -49,14 +49,14 @@ class QuantizationCompressor(Compressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        model_quant_args: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict

         :param model_state: state dict of uncompressed model
-        :param model_quant_args: quantization args for each quantized weight, needed for
+        :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
         :return: compressed state dict
         """
@@ -73,7 +73,7 @@ def compress(
            zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
            if scale is not None and zp is not None:
                # weight is quantized, compress it
-               quant_args = model_quant_args[prefix]
+               quant_args = names_to_scheme[prefix]
                if can_quantize(value, quant_args):
                    # only quantize if not already quantized
                    value = quantize(

src/compressed_tensors/compressors/pack_quantized.py

Lines changed: 84 additions & 20 deletions

@@ -29,7 +29,13 @@
 from tqdm import tqdm


-__all__ = ["PackedQuantizationCompressor", "pack_4bit_ints", "unpack_4bit_ints"]
+__all__ = [
+    "PackedQuantizationCompressor",
+    "pack_4bit_ints",
+    "pack_8bit_ints",
+    "unpack_4bit_ints",
+    "unpack_8bit_ints",
+]

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -50,14 +56,14 @@ class PackedQuantizationCompressor(Compressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        model_quant_args: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict

         :param model_state: state dict of uncompressed model
-        :param model_quant_args: quantization args for each quantized weight, needed for
+        :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
         :return: compressed state dict
         """
@@ -75,7 +81,7 @@ def compress(
             shape = torch.tensor(value.shape)
             if scale is not None and zp is not None:
                 # weight is quantized, compress it
-                quant_args = model_quant_args[prefix]
+                quant_args = names_to_scheme[prefix]
                 if can_quantize(value, quant_args):
                     # convert weight to an int if not already compressed
                     value = quantize(
@@ -85,7 +91,11 @@ def compress(
                         args=quant_args,
                         dtype=torch.int8,
                     )
-                value = pack_4bit_ints(value.cpu())
+
+                if quant_args.num_bits == 8:
+                    value = pack_8bit_ints(value.cpu())
+                else:
+                    value = pack_4bit_ints(value.cpu())
                 compressed_dict[merge_names(prefix, "weight_shape")] = shape
                 compressed_dict[merge_names(prefix, "weight_packed")] = value
                 continue
@@ -101,7 +111,10 @@ def compress(
         return compressed_dict

     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self,
+        path_to_model_or_tensors: str,
+        names_to_scheme: Dict[str, QuantizationArgs],
+        device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
@@ -119,6 +132,7 @@ def decompress(
         for weight_name in weight_mappings.keys():
             weight_data = {}
             for param_name, safe_path in weight_mappings[weight_name].items():
+                weight_data["num_bits"] = names_to_scheme.get(weight_name).num_bits
                 full_name = merge_names(weight_name, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
@@ -127,8 +141,12 @@ def decompress(
                 zero_point = weight_data.get("weight_zero_point", None)
                 scale = weight_data["weight_scale"]
                 weight = weight_data["weight_packed"]
+                num_bits = weight_data["num_bits"]
                 original_shape = torch.Size(weight_data["weight_shape"])
-                unpacked = unpack_4bit_ints(weight, original_shape)
+                if num_bits == 4:
+                    unpacked = unpack_4bit_ints(weight, original_shape)
+                else:
+                    unpacked = unpack_8bit_ints(weight, original_shape)
                 decompressed = dequantize(
                     x_q=unpacked,
                     scale=scale,
@@ -137,6 +155,19 @@ def decompress(
                 yield merge_names(weight_name, "weight"), decompressed


+def pack_8bit_ints(value: torch.Tensor) -> torch.Tensor:
+    """
+    Packs a tensor of int8 into int32s with padding
+
+    :param value: tensor to pack
+    :returns: packed int32 tensor
+    """
+    # need to convert to unsigned 8bit to use numpy's pack/unpack
+    value_uint = (value - 128).to(torch.uint8)
+    bits = np.unpackbits(value_uint, axis=-1, bitorder="little")
+    return _pack_bits(bits_to_pack=bits)
+
+
 def pack_4bit_ints(value: torch.Tensor) -> torch.Tensor:
     """
     Packs a tensor of int4 weights stored in int8 into int32s with padding
@@ -152,22 +183,31 @@ def pack_4bit_ints(value: torch.Tensor) -> torch.Tensor:
     bits = np.unpackbits(temp.numpy(), axis=-1, bitorder="little")
     ranges = np.array([range(x, x + 4) for x in range(0, bits.shape[1], 8)]).flatten()
     only_4_bits = bits[:, ranges]  # top 4 bits are 0 because we're really uint4
+    return _pack_bits(bits_to_pack=only_4_bits)

-    # pad each row to fill a full 32bit int
-    pack_depth = 32
-    padding = (
-        math.ceil(only_4_bits.shape[1] / pack_depth) * pack_depth - only_4_bits.shape[1]
-    )
-    padded_bits = np.pad(
-        only_4_bits, pad_width=[(0, 0), (0, padding)], constant_values=0
-    )
-
-    # after packbits each uint8 is two packed uint4s
-    # then we keep the bit pattern the same but convert to int32
-    compressed = np.packbits(padded_bits, axis=-1, bitorder="little")
-    compressed = np.ascontiguousarray(compressed).view(np.int32)
-
-    return torch.from_numpy(compressed)

+def unpack_8bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
+    """
+    Unpacks a tensor of packed int8 weights stored in int32
+
+    :param value: tensor to unpack
+    :param shape: shape to unpack into, used to remove padding
+    :returns: unpacked int8 tensor
+    """
+    if value.dtype is not torch.int32:
+        raise ValueError(
+            f"Expected {torch.int32} but got {value.dtype}, Aborting unpack."
+        )
+
+    # unpack bits and undo padding to nearest int32 bits
+    individual_depth = 8
+    as_uint8 = value.numpy().view(np.uint8)
+    bits = np.unpackbits(as_uint8, axis=-1, bitorder="little")
+    original_row_size = int(shape[1] * individual_depth)
+    bits = bits[:, :original_row_size]
+    bits = np.packbits(bits, axis=-1, bitorder="little")
+    final = (bits - 128).astype(np.int8)
+    return torch.from_numpy(final)


 def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
@@ -206,3 +246,27 @@ def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
     final = repacked.astype(np.int8) - 8

     return torch.from_numpy(final)
+
+
+def _pack_bits(bits_to_pack: torch.Tensor) -> torch.Tensor:
+    """
+    Pack a tensor of bits to int32.
+
+    :param bits_to_pack: tensor of bits to pack
+    """
+    # pad each row to fill a full 32bit int
+    pack_depth = 32
+    padding = (
+        math.ceil(bits_to_pack.shape[1] / pack_depth) * pack_depth
+        - bits_to_pack.shape[1]
+    )
+    padded_bits = np.pad(
+        bits_to_pack, pad_width=[(0, 0), (0, padding)], constant_values=0
+    )
+
+    # after packbits each uint8 is two packed uint4s
+    # then we keep the bit pattern the same but convert to int32
+    compressed = np.packbits(padded_bits, axis=-1, bitorder="little")
+    compressed = np.ascontiguousarray(compressed).view(np.int32)
+
+    return torch.from_numpy(compressed)
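As a quick illustration of the new helpers, here is a minimal round-trip sketch; it is not taken from the diff, and the import path and example tensor shape are assumptions. The 8-bit path mirrors the 4-bit path, just without the nibble-selection step: values are offset into uint8, unpacked to bits, and padded to a 32-bit row boundary by the shared _pack_bits helper.

# Hypothetical round-trip check for the new 8-bit packing helpers.
# Assumes pack_8bit_ints / unpack_8bit_ints are importable from the module
# shown above; the weight shape below is made up for illustration.
import torch

from compressed_tensors.compressors.pack_quantized import (
    pack_8bit_ints,
    unpack_8bit_ints,
)

# fake int8 "weight" whose rows are 10 bytes = 80 bits, not a multiple of 32,
# so each row gets padded before being viewed as int32 (3 int32s per row)
weight = torch.randint(-128, 127, (16, 10), dtype=torch.int8)

packed = pack_8bit_ints(weight)  # int32 tensor, rows padded to a 32-bit boundary
assert packed.dtype == torch.int32

# the original shape is what decompress() reads back from "weight_shape"
unpacked = unpack_8bit_ints(packed, weight.shape)
assert torch.equal(unpacked, weight)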

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 5 additions & 1 deletion

@@ -96,7 +96,7 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
     )


-def apply_quantization_config(model: Module, config: QuantizationConfig):
+def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict:
     """
     Initializes the model for quantization in-place based on the given config

@@ -106,6 +106,7 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
+    names_to_scheme = OrderedDict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
@@ -123,6 +124,7 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
         if target is not None:
             # target matched - add layer and scheme to target list
             submodule.quantization_scheme = target_to_scheme[target]
+            names_to_scheme[name] = submodule.quantization_scheme.weights

     if config.ignore is not None and ignored_submodules is not None:
         if set(config.ignore) - set(ignored_submodules):
@@ -132,7 +134,9 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
                 f"{set(config.ignore) - set(ignored_submodules)}"
             )
     # apply current quantization status across all targeted layers
+
     apply_quantization_status(model, config.quantization_status)
+    return names_to_scheme


 def apply_quantization_status(model: Module, status: QuantizationStatus):
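Taken together with the model_compressor.py change above, the intended call pattern looks roughly like the sketch below. Only apply_quantization_config and the names_to_scheme keyword come from this commit; the wrapper function, its arguments, and the import path are placeholders for illustration.

# Rough caller-side sketch (placeholders, not repository code): the mapping
# returned by apply_quantization_config is forwarded to decompress so the
# packed compressor can pick 4-bit vs 8-bit unpacking per weight.
from torch.nn import Module

from compressed_tensors.quantization.lifecycle.apply import apply_quantization_config


def reload_dense_weights(model: Module, quantization_config, compressor, model_path: str):
    # now returns an OrderedDict of layer name -> weight QuantizationArgs
    names_to_scheme = apply_quantization_config(model, quantization_config)
    # names_to_scheme is the keyword added to decompress in this commit
    yield from compressor.decompress(model_path, names_to_scheme=names_to_scheme)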

src/compressed_tensors/utils/helpers.py

Lines changed: 0 additions & 1 deletion

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 from typing import Optional

 from transformers import AutoConfig

tests/test_compressors/test_fp8_quant.py

Lines changed: 2 additions & 2 deletions

@@ -79,7 +79,7 @@ def test_quant_format(strategy, group_size, sc, zp):
     compressor = FloatQuantizationCompressor(config=quant_config)
     quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
     compressed_state_dict = compressor.compress(
-        dense_state_dict, model_quant_args=quantized_modules_to_args
+        dense_state_dict, names_to_scheme=quantized_modules_to_args
     )

     # state_dict params should be the same, minus the zero_point if symmetric
@@ -118,7 +118,7 @@ def test_reload_match(strategy, group_size, tmp_path):
         "dummy": quant_config.config_groups["group_1"].weights,
     }
     compressed_state_dict = compressor.compress(
-        model.state_dict(), model_quant_args=quantized_modules_to_args
+        model.state_dict(), names_to_scheme=quantized_modules_to_args
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
     reconstructed_dense_gen = compressor.decompress(tmp_path)

tests/test_compressors/test_int_quant.py

Lines changed: 2 additions & 2 deletions

@@ -74,7 +74,7 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
     compressor = IntQuantizationCompressor(config=quant_config)
     quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
     compressed_state_dict = compressor.compress(
-        dense_state_dict, model_quant_args=quantized_modules_to_args
+        dense_state_dict, names_to_scheme=quantized_modules_to_args
     )

     # state_dict params should be the same, minus the zero_point if symmetric
@@ -125,7 +125,7 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path):
         "dummy2": quant_config.config_groups["group_1"].weights,
     }
     compressed_state_dict = compressor.compress(
-        dense_state_dict, model_quant_args=quantized_modules_to_args
+        dense_state_dict, names_to_scheme=quantized_modules_to_args
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
     reconstructed_dense_gen = compressor.decompress(tmp_path)
