 from tqdm import tqdm


-__all__ = ["PackedQuantizationCompressor", "pack_4bit_ints", "unpack_4bit_ints"]
+__all__ = [
+    "PackedQuantizationCompressor",
+    "pack_4bit_ints",
+    "pack_8bit_ints",
+    "unpack_4bit_ints",
+    "unpack_8bit_ints",
+]

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -50,14 +56,14 @@ class PackedQuantizationCompressor(Compressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        model_quant_args: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict

         :param model_state: state dict of uncompressed model
-        :param model_quant_args: quantization args for each quantized weight, needed for
+        :param names_to_scheme: quantization args for each quantized weight, needed for
         quantize function to calculate bit depth
         :return: compressed state dict
         """
@@ -75,7 +81,7 @@ def compress(
                 shape = torch.tensor(value.shape)
                 if scale is not None and zp is not None:
                     # weight is quantized, compress it
-                    quant_args = model_quant_args[prefix]
+                    quant_args = names_to_scheme[prefix]
                     if can_quantize(value, quant_args):
                         # convert weight to an int if not already compressed
                         value = quantize(
@@ -85,7 +91,11 @@ def compress(
                             args=quant_args,
                             dtype=torch.int8,
                         )
-                    value = pack_4bit_ints(value.cpu())
+
+                    if quant_args.num_bits == 8:
+                        value = pack_8bit_ints(value.cpu())
+                    else:
+                        value = pack_4bit_ints(value.cpu())
                     compressed_dict[merge_names(prefix, "weight_shape")] = shape
                     compressed_dict[merge_names(prefix, "weight_packed")] = value
                     continue
@@ -101,7 +111,10 @@ def compress(
         return compressed_dict

     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self,
+        path_to_model_or_tensors: str,
+        names_to_scheme: Dict[str, QuantizationArgs],
+        device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
@@ -119,6 +132,7 @@ def decompress(
         for weight_name in weight_mappings.keys():
             weight_data = {}
             for param_name, safe_path in weight_mappings[weight_name].items():
+                weight_data["num_bits"] = names_to_scheme.get(weight_name).num_bits
                 full_name = merge_names(weight_name, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
@@ -127,8 +141,12 @@ def decompress(
                 zero_point = weight_data.get("weight_zero_point", None)
                 scale = weight_data["weight_scale"]
                 weight = weight_data["weight_packed"]
+                num_bits = weight_data["num_bits"]
                 original_shape = torch.Size(weight_data["weight_shape"])
-                unpacked = unpack_4bit_ints(weight, original_shape)
+                if num_bits == 4:
+                    unpacked = unpack_4bit_ints(weight, original_shape)
+                else:
+                    unpacked = unpack_8bit_ints(weight, original_shape)
                 decompressed = dequantize(
                     x_q=unpacked,
                     scale=scale,
@@ -137,6 +155,19 @@ def decompress(
         yield merge_names(weight_name, "weight"), decompressed


+def pack_8bit_ints(value: torch.Tensor) -> torch.Tensor:
+    """
+    Packs a tensor of int8 weights into int32s with padding
+
+    :param value: tensor to pack
+    :returns: packed int32 tensor
+    """
+    # shift int8 into uint8 for numpy's pack/unpack; int8 wraparound makes -128 a +128 offset
+    value_uint = (value - 128).to(torch.uint8)
+    bits = np.unpackbits(value_uint.numpy(), axis=-1, bitorder="little")
+    return _pack_bits(bits_to_pack=bits)
+
+
 def pack_4bit_ints(value: torch.Tensor) -> torch.Tensor:
     """
     Packs a tensor of int4 weights stored in int8 into int32s with padding
@@ -152,22 +183,31 @@ def pack_4bit_ints(value: torch.Tensor) -> torch.Tensor:
     bits = np.unpackbits(temp.numpy(), axis=-1, bitorder="little")
     ranges = np.array([range(x, x + 4) for x in range(0, bits.shape[1], 8)]).flatten()
     only_4_bits = bits[:, ranges]  # top 4 bits are 0 because we're really uint4
+    return _pack_bits(bits_to_pack=only_4_bits)

-    # pad each row to fill a full 32bit int
-    pack_depth = 32
-    padding = (
-        math.ceil(only_4_bits.shape[1] / pack_depth) * pack_depth - only_4_bits.shape[1]
-    )
-    padded_bits = np.pad(
-        only_4_bits, pad_width=[(0, 0), (0, padding)], constant_values=0
-    )

-    # after packbits each uint8 is two packed uint4s
-    # then we keep the bit pattern the same but convert to int32
-    compressed = np.packbits(padded_bits, axis=-1, bitorder="little")
-    compressed = np.ascontiguousarray(compressed).view(np.int32)
+def unpack_8bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
+    """
+    Unpacks a tensor of packed int8 weights stored in int32s

-    return torch.from_numpy(compressed)
+    :param value: tensor to unpack
+    :param shape: shape to unpack into, used to remove padding
+    :returns: unpacked int8 tensor
+    """
+    if value.dtype is not torch.int32:
+        raise ValueError(
+            f"Expected {torch.int32} but got {value.dtype}, aborting unpack."
+        )
+
+    # unpack the bits and drop the padding added to reach a 32-bit boundary
+    individual_depth = 8
+    as_uint8 = value.numpy().view(np.uint8)
+    bits = np.unpackbits(as_uint8, axis=-1, bitorder="little")
+    original_row_size = int(shape[1] * individual_depth)
+    bits = bits[:, :original_row_size]
+    bits = np.packbits(bits, axis=-1, bitorder="little")
+    final = (bits - 128).astype(np.int8)  # undo the +128 offset applied when packing
+    return torch.from_numpy(final)


 def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
@@ -206,3 +246,27 @@ def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
     final = repacked.astype(np.int8) - 8

     return torch.from_numpy(final)
+
+
+def _pack_bits(bits_to_pack: np.ndarray) -> torch.Tensor:
+    """
+    Packs a 2D array of bits into int32s, padding each row to a 32-bit boundary.
+
+    :param bits_to_pack: array of bits to pack, as produced by np.unpackbits
+    :returns: packed int32 tensor
+    """
+    # pad each row to fill a full 32bit int
+    pack_depth = 32
+    padding = (
+        math.ceil(bits_to_pack.shape[1] / pack_depth) * pack_depth
+        - bits_to_pack.shape[1]
+    )
+    padded_bits = np.pad(
+        bits_to_pack, pad_width=[(0, 0), (0, padding)], constant_values=0
+    )
+
+    # pack the bits back into uint8s, then reinterpret the same bit pattern
+    # as int32 without changing any bytes
+    compressed = np.packbits(padded_bits, axis=-1, bitorder="little")
+    compressed = np.ascontiguousarray(compressed).view(np.int32)
+
+    return torch.from_numpy(compressed)
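
As a sanity check on the new 8-bit path, here is a minimal numpy-only sketch (not part of the diff) of the round trip that pack_8bit_ints, _pack_bits, and unpack_8bit_ints perform together; the offset-binary shift is written with an explicit widening cast, where the diff relies on int8 wraparound for the same result:

import math

import numpy as np

# two rows of six int8 values each: 48 bits per row, padded to 64 (two int32s)
weights = np.array(
    [[-128, -1, 0, 1, 64, 127], [3, -3, 50, -50, 100, -100]], dtype=np.int8
)

# offset-binary shift: int8 [-128, 127] -> uint8 [0, 255]
as_uint8 = (weights.astype(np.int16) + 128).astype(np.uint8)
bits = np.unpackbits(as_uint8, axis=-1, bitorder="little")  # shape (2, 48)

# pad each row to a multiple of 32 bits, then pack and view as int32
pack_depth = 32
padding = math.ceil(bits.shape[1] / pack_depth) * pack_depth - bits.shape[1]
padded = np.pad(bits, pad_width=[(0, 0), (0, padding)], constant_values=0)
packed = np.ascontiguousarray(
    np.packbits(padded, axis=-1, bitorder="little")
).view(np.int32)
assert packed.shape == (2, 2)  # ceil(48 / 32) = 2 int32s per row

# reverse: bytes -> bits, drop the pad, repack, undo the +128 offset
bits_back = np.unpackbits(packed.view(np.uint8), axis=-1, bitorder="little")
restored = np.packbits(bits_back[:, :48], axis=-1, bitorder="little")
restored = (restored.astype(np.int16) - 128).astype(np.int8)
assert np.array_equal(restored, weights)  # lossless round trip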
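The 4-bit path differs only in how bits are gathered before _pack_bits: after the +8 offset each value occupies the low four bits of its byte, so pack_4bit_ints keeps just those bits and halves the storage. A small standalone illustration of that selection, mirroring the ranges indexing above:

import numpy as np

# eight int4 values per row, stored in int8: 8 * 4 = 32 bits, exactly one int32
row = np.array([[-8, -4, -1, 0, 1, 3, 5, 7]], dtype=np.int8)

# +8 shifts int4 [-8, 7] into uint4 [0, 15]; the top four bits of each byte stay zero
temp = (row + 8).astype(np.uint8)
bits = np.unpackbits(temp, axis=-1, bitorder="little")  # shape (1, 64)

# with bitorder="little" the low four bits come first in each group of eight
ranges = np.array([range(x, x + 4) for x in range(0, bits.shape[1], 8)]).flatten()
only_4_bits = bits[:, ranges]  # shape (1, 32)

packed = np.ascontiguousarray(
    np.packbits(only_4_bits, axis=-1, bitorder="little")
).view(np.int32)
assert packed.shape == (1, 1)  # eight 4-bit weights in one int32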