 # since the model classes inherit torch.nn.Module.
 import math

+import numpy as np
 import torch
 from packaging.version import Version
 from torch.autograd import Function
@@ -325,11 +326,89 @@ def __init__(
         else:
             self.g_idx = None

+    def pack_tensor_with_numpy(self, raw_tensor):
+        raw_array = raw_tensor.cpu().numpy()
+        target_len = np.ceil(raw_array.shape[1] / self.n_pack).astype(int)
+        target_dtype = torch.tensor(0, dtype=self.compression_dtype).numpy().dtype
+        packed_array = np.zeros((raw_array.shape[0], target_len), dtype=target_dtype)
+        mask = np.uint8(2**self.bits - 1)
+        for j in range(packed_array.shape[1]):
+            start = self.n_pack * j
+            end = self.n_pack * (j + 1)
+            tmp = raw_array[:, start:end].astype(target_dtype)
+            tmp &= mask
+            for e in range(tmp.shape[1]):
+                tmp[:, e] = np.left_shift(tmp[:, e], self.bits * e)
+                packed_array[:, j] |= tmp[:, e]
+        packed_tensor = torch.from_numpy(packed_array).to(device=raw_tensor.device)
+        return packed_tensor
+
+    def unpack_tensor_with_numpy(self, packed_tensor):
+        packed_array = packed_tensor.cpu().numpy()
+        target_dtype = np.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else np.uint8
+        target_len = packed_array.shape[1] * self.n_pack
+        unpacked_array = np.zeros((packed_array.shape[0], target_len), dtype=target_dtype)
+        mask = np.uint8(2**self.bits - 1)
+        for j in range(packed_array.shape[1]):
+            for e in range(self.n_pack):
+                index = j * self.n_pack + e
+                tmp = packed_array[:, j]
+                tmp = np.left_shift(tmp, self.compress_bits - self.bits * (e + 1))
+                tmp = np.right_shift(tmp, self.compress_bits - self.bits)
+                if target_dtype == np.uint8:
+                    tmp &= mask
+                unpacked_array[:, index] = tmp.astype(target_dtype)
+        unpacked_tensor = torch.from_numpy(unpacked_array).to(device=packed_tensor.device)
+        return unpacked_tensor
+
+    def pack_tensor_with_torch(self, raw_tensor):
+        target_len = math.ceil(raw_tensor.shape[1] / self.n_pack)
+        packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        for j in range(packed_tensor.shape[1]):
+            start = self.n_pack * j
+            end = self.n_pack * (j + 1)
+            tmp = raw_tensor[:, start:end].type(self.compression_dtype)
+            tmp &= mask
+            for e in range(tmp.shape[1]):
+                tmp[:, e] = tmp[:, e] << (self.bits * e)
+                packed_tensor[:, j] |= tmp[:, e]
+        return packed_tensor
+
+    def unpack_tensor_with_torch(self, packed_tensor):
+        target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
+        target_len = packed_tensor.shape[1] * self.n_pack
+        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(self.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        for j in range(packed_tensor.shape[1]):
+            for e in range(self.n_pack):
+                index = j * self.n_pack + e
+                tmp = packed_tensor[:, j]
+                tmp = tmp << (self.compress_bits - self.bits * (e + 1))
+                tmp = tmp >> self.compress_bits - self.bits
+                if target_dtype == torch.uint8:
+                    tmp &= mask  # remove sign bit
+                unpacked_tensor[:, index].copy_(tmp.type(target_dtype))
+        logger.info(f"*****{unpacked_tensor}")
+        return unpacked_tensor
+
+    def pack_tensor(self, raw_tensor):
+        if "cuda" in self.device:
+            return self.pack_tensor_with_torch(raw_tensor)
+        else:
+            return self.pack_tensor_with_numpy(raw_tensor)
+
+    def unpack_tensor(self, packed_tensor):
+        if "cuda" in self.device:
+            return self.unpack_tensor_with_torch(packed_tensor)
+        else:
+            return self.unpack_tensor_with_numpy(packed_tensor)
+
     def pack(self, int_weight, scale, zp, bias, g_idx=None):
         if self.use_optimum_format:
-            self.scales = self.scales.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
-            self.qzeros = self.qzeros.t_().contiguous()
+            self.scales = self.scales.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
+            self.qzeros = self.qzeros.T.contiguous()
         int_weight = int_weight.to(self.device)
         if self.use_optimum_format and zp is None:
             # to avoid overflow
@@ -350,118 +429,73 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
         assert scale.shape == self.scales.shape, "Scale shape is mismatched."
         self.scales = scale.type(self.float_type).to(self.device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            int_weight = int_weight.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
+            int_weight = int_weight.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
         origin_shape = int_weight.shape
         target_shape = self.qweight.shape
         assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)

         # pack weight
-        for j in range(target_shape[1]):
-            start = self.n_pack * j
-            end = self.n_pack * (j + 1)
-            tmp = int_weight[:, start:end].type(self.compression_dtype)
-            for e in range(tmp.shape[1]):
-                tmp[:, e] &= mask
-                tmp[:, e] = tmp[:, e] << (self.bits * e)
-                self.qweight[:, j] |= tmp[:, e]
+        self.qweight.copy_(self.pack_tensor(int_weight))
         if not self.use_optimum_format and self.compression_dim == 0:
-            self.qweight = self.qweight.t_().contiguous()
+            self.qweight = self.qweight.T.contiguous()

         if zp is not None:
             zp = zp.to(self.device)
             if self.use_optimum_format:
                 zp -= 1
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
-                self.qzeros = self.qzeros.t_().contiguous()
+                zp = zp.T.contiguous()
+                self.qzeros = self.qzeros.T.contiguous()
             assert hasattr(self, "qzeros"), "zp is not set when initializing."
-            target_shape = self.qzeros.shape
-            for j in range(target_shape[1]):
-                start = self.n_pack * j
-                end = self.n_pack * (j + 1)
-                tmp = zp[:, start:end].type(self.compression_dtype)
-                for e in range(tmp.shape[1]):
-                    tmp[:, e] &= mask
-                    tmp[:, e] = tmp[:, e] << (self.bits * e)
-                    self.qzeros[:, j] |= tmp[:, e]
+            self.qzeros.copy_(self.pack_tensor(zp))
             if self.use_optimum_format or self.compression_dim == 0:
-                self.qzeros = self.qzeros.t_().contiguous()
+                self.qzeros = self.qzeros.T.contiguous()
         if self.use_optimum_format:
-            self.scales = self.scales.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
-            self.qzeros = self.qzeros.t_().contiguous()
+            self.scales = self.scales.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
+            self.qzeros = self.qzeros.T.contiguous()

     def recover(self):
         logger.debug(f"Recovering {self} weight")
-        scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
-        qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight
+        scales = self.scales.T.contiguous() if self.use_optimum_format else self.scales
+        qweight = self.qweight.T.contiguous() if self.use_optimum_format else self.qweight

         device = scales.device
         fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
         if self.g_idx is None:
             # used for recovering fp32_weight
             self.g_idx = torch.tensor([i // self.groupsize for i in range(self.in_features)], dtype=torch.int32)
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(device)
-        if hasattr(self, "qzeros"):
-            weight_dtype = torch.uint8
-        else:
-            weight_dtype = torch.int8
         # unpack weight
-        weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.t_().contiguous()
-            qweight = qweight.t_().contiguous()
-        origin_shape = weight.shape
-        target_shape = qweight.shape
-        for j in range(target_shape[1]):
-            for e in range(self.n_pack):
-                index = j * self.n_pack + e
-                if index >= origin_shape[1]:
-                    continue
-                tmp = qweight[:, j]
-                tmp = tmp << (self.compress_bits - self.bits * (e + 1))
-                tmp = tmp >> self.compress_bits - self.bits
-                if weight_dtype == torch.uint8:
-                    tmp &= mask  # remove sign bit
-                weight[:, index] = tmp.type(weight_dtype)
+            qweight = qweight.T.contiguous()
+        weight = self.unpack_tensor(qweight)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.t_().contiguous()
+            weight = weight.T.contiguous()
+        weight = weight[: self.out_features, : self.in_features]  # avoid oversize
         if "int" not in self.dtype:
             new_weight = torch.zeros(self.out_features, self.in_features).to(device)
             for k, v in self.int2float_mapping.items():
                 new_weight += torch.where(weight == k, v, 0)
             weight = new_weight
         # unpack zero_point
         if hasattr(self, "qzeros"):
-            zp_dtype = self.compression_dtype  # to avoid overflow when weight-zp
-            zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
-            qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
+            qzeros = self.qzeros.T.contiguous() if self.use_optimum_format else self.qzeros
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
-                qzeros = qzeros.t_().contiguous()
-            origin_shape = zp.shape
-            target_shape = qzeros.shape
-            for j in range(target_shape[1]):
-                for e in range(self.n_pack):
-                    index = j * self.n_pack + e
-                    if index >= origin_shape[1]:
-                        continue
-                    tmp = qzeros[:, j]
-                    tmp = tmp << (self.compress_bits - self.bits * (e + 1))
-                    tmp = tmp >> self.compress_bits - self.bits
-                    tmp &= mask
-                    zp[:, index] = tmp.type(zp_dtype)
+                qzeros = qzeros.T.contiguous()
+            zp = self.unpack_tensor(qzeros)
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
+                zp = zp.T.contiguous()
+            zp = zp[: scales.shape[0], : scales.shape[1]]  # avoid oversize
             if self.use_optimum_format:
                 # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
                 zp += 1
                 zp = torch.where(zp > (2**self.bits - 1), 0, zp)
             # recover fp32 weight with int_weight, scale, and zero_point
             for idx in range(self.in_features):
-                fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * scales[:, self.g_idx[idx]]
+                fp32_weight[:, idx] = (torch.subtract(weight[:, idx], zp[:, self.g_idx[idx]]).to(torch.int8)) * scales[
+                    :, self.g_idx[idx]
+                ]
         else:
             # recover fp32 weight with int_weight, scale
             for idx in range(self.in_features):
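For reference, a minimal standalone sketch of the bit-packing scheme that the new pack_tensor/unpack_tensor helpers implement: values are masked to `bits` bits, shifted into their slot, and OR-ed into a wider container; unpacking shifts the slot to the top of the container and arithmetic-shifts it back down, masking off the sign extension. The names bits, compress_bits and n_pack, and the 4-bit/int32 choice, are assumptions standing in for the module attributes; this is illustrative only and not part of the patch.

import torch

bits, compress_bits = 4, 32     # assumption: 4-bit values in an int32 container
n_pack = compress_bits // bits  # 8 values per packed element
raw = torch.randint(0, 2**bits, (2, 16), dtype=torch.int32)

# pack: mask each value to `bits` bits, shift it into its slot, OR it into the container
packed = torch.zeros(raw.shape[0], raw.shape[1] // n_pack, dtype=torch.int32)
for j in range(packed.shape[1]):
    chunk = raw[:, j * n_pack:(j + 1) * n_pack] & (2**bits - 1)
    for e in range(n_pack):
        packed[:, j] |= chunk[:, e] << (bits * e)

# unpack: shift the target slot to the top, arithmetic-shift it back, mask off sign bits
unpacked = torch.zeros_like(raw)
for j in range(packed.shape[1]):
    for e in range(n_pack):
        tmp = packed[:, j] << (compress_bits - bits * (e + 1))
        unpacked[:, j * n_pack + e] = (tmp >> (compress_bits - bits)) & (2**bits - 1)

assert torch.equal(raw, unpacked)  # lossless round trip for values in [0, 2**bits)

The final mask with 2**bits - 1 is what removes the sign extension introduced by the arithmetic right shift, which is why the helpers above apply it only when the unpacked target dtype is unsigned.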