@@ -127,23 +127,20 @@ def __init__(
127127 dtype = self .float_type ,
128128 ).to (device ),
129129 )
130- self .scales = self .scales .T
131130 self .register_buffer (
132131 "qweight" ,
133132 torch .zeros (
134133 (math .ceil (in_features / self .n_pack ), out_features ),
135134 dtype = self .compression_dtype ,
136135 ).to (device ),
137136 )
138- self .qweight = self .qweight .T
139137 self .register_buffer (
140138 "qzeros" ,
141139 torch .zeros (
142140 (math .ceil (self .in_features / self .groupsize ), math .ceil (self .out_features / self .n_pack )),
143141 dtype = self .compression_dtype ,
144142 ).to (device ),
145143 )
146- self .qzeros = self .qzeros .T
147144 self .register_buffer ("bias" , torch .zeros (self .out_features , dtype = self .float_type ).to (device ))
148145 else :
149146 self .compression_dtype = compression_dtype
@@ -193,6 +190,10 @@ def __init__(
193190 self .bias = None
194191
195192 def pack (self , int_weight , scale , zp , bias ):
193+ if self .use_optimum_format :
194+ self .scales = self .scales .t_ ().contiguous ()
195+ self .qweight = self .qweight .t_ ().contiguous ()
196+ self .qzeros = self .qzeros .t_ ().contiguous ()
196197 int_weight = int_weight .to (self .device )
197198 if self .use_optimum_format and zp is None :
198199 # to avoid overflow
@@ -206,8 +207,8 @@ def pack(self, int_weight, scale, zp, bias):
206207 assert scale .shape == self .scales .shape , "Scale shape is mismatched."
207208 self .scales = scale .type (self .float_type ).to (self .device )
208209 if not self .use_optimum_format and self .compression_dim == 0 :
209- int_weight = int_weight .T
210- self .qweight = self .qweight .T
210+ int_weight = int_weight .t_ (). contiguous ()
211+ self .qweight = self .qweight .t_ (). contiguous ()
211212 origin_shape = int_weight .shape
212213 target_shape = self .qweight .shape
213214 assert origin_shape [0 ] == target_shape [0 ], "output channels mismatch, please check."
@@ -223,15 +224,15 @@ def pack(self, int_weight, scale, zp, bias):
223224 tmp [:, e ] = tmp [:, e ] << (self .bits * e )
224225 self .qweight [:, j ] |= tmp [:, e ]
225226 if not self .use_optimum_format and self .compression_dim == 0 :
226- self .qweight = self .qweight .T
227+ self .qweight = self .qweight .t_ (). contiguous ()
227228
228229 if zp is not None :
229230 zp = zp .to (self .device )
230231 if self .use_optimum_format :
231232 zp -= 1
232233 if self .use_optimum_format or self .compression_dim == 0 :
233- zp = zp .T
234- self .qzeros = self .qzeros .T
234+ zp = zp .t_ (). contiguous ()
235+ self .qzeros = self .qzeros .t_ (). contiguous ()
235236 assert hasattr (self , "qzeros" ), "zp is not set when initializing."
236237 target_shape = self .qzeros .shape
237238 for j in range (target_shape [1 ]):
@@ -243,16 +244,16 @@ def pack(self, int_weight, scale, zp, bias):
243244 tmp [:, e ] = tmp [:, e ] << (self .bits * e )
244245 self .qzeros [:, j ] |= tmp [:, e ]
245246 if self .use_optimum_format or self .compression_dim == 0 :
246- self .qzeros = self .qzeros .T
247+ self .qzeros = self .qzeros .t_ (). contiguous ()
247248 if self .use_optimum_format :
248- self .scales = self .scales .T
249- self .qweight = self .qweight .T
250- self .qzeros = self .qzeros .T
249+ self .scales = self .scales .t_ (). contiguous ()
250+ self .qweight = self .qweight .t_ (). contiguous ()
251+ self .qzeros = self .qzeros .t_ (). contiguous ()
251252
252253 def recover (self ):
253254 logger .debug (f"Recovering { self } weight" )
254- scales = self .scales .T if self .use_optimum_format else self .scales
255- qweight = self .qweight .T if self .use_optimum_format else self .qweight
255+ scales = self .scales .t_ (). contiguous () if self .use_optimum_format else self .scales
256+ qweight = self .qweight .t_ (). contiguous () if self .use_optimum_format else self .qweight
256257
257258 device = scales .device
258259 fp32_weight = torch .zeros (self .out_features , self .in_features , dtype = self .float_type ).to (device )
@@ -264,8 +265,8 @@ def recover(self):
264265 # unpack weight
265266 weight = torch .zeros (self .out_features , self .in_features , dtype = weight_dtype ).to (device )
266267 if not self .use_optimum_format and self .compression_dim == 0 :
267- weight = weight .T
268- qweight = qweight .T
268+ weight = weight .t_ (). contiguous ()
269+ qweight = qweight .t_ (). contiguous ()
269270 origin_shape = weight .shape
270271 target_shape = qweight .shape
271272 for j in range (target_shape [1 ]):
@@ -280,7 +281,7 @@ def recover(self):
280281 tmp &= mask # remove sign bit
281282 weight [:, index ] = tmp .type (weight_dtype )
282283 if not self .use_optimum_format and self .compression_dim == 0 :
283- weight = weight .T
284+ weight = weight .t_ (). contiguous ()
284285 if "int" not in self .dtype :
285286 new_weight = torch .zeros (self .out_features , self .in_features ).to (device )
286287 for k , v in self .int2float_mapping .items ():
@@ -290,10 +291,10 @@ def recover(self):
290291 if hasattr (self , "qzeros" ):
291292 zp_dtype = self .compression_dtype # to avoid overflow when weight-zp
292293 zp = torch .zeros (scales .shape , dtype = zp_dtype ).to (device )
293- qzeros = self .qzeros .T if self .use_optimum_format else self .qzeros
294+ qzeros = self .qzeros .t_ (). contiguous () if self .use_optimum_format else self .qzeros
294295 if self .use_optimum_format or self .compression_dim == 0 :
295- zp = zp .T
296- qzeros = qzeros .T
296+ zp = zp .t_ (). contiguous ()
297+ qzeros = qzeros .t_ (). contiguous ()
297298 origin_shape = zp .shape
298299 target_shape = qzeros .shape
299300 for j in range (target_shape [1 ]):
@@ -307,7 +308,7 @@ def recover(self):
307308 tmp &= mask
308309 zp [:, index ] = tmp .type (zp_dtype )
309310 if self .use_optimum_format or self .compression_dim == 0 :
310- zp = zp .T
311+ zp = zp .t_ (). contiguous ()
311312 if self .use_optimum_format :
312313 # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
313314 zp += 1
0 commit comments