Commit daa1431

Fix WOQ Linear pack/unpack slow issue 2x (#1837)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
1 parent 29fdecb commit daa1431

File tree

1 file changed: +107 additions, -73 deletions


neural_compressor/adaptor/torch_utils/model_wrapper.py

Lines changed: 107 additions & 73 deletions
@@ -19,6 +19,7 @@
 # since the model classes inherit torch.nn.Module.
 import math
 
+import numpy as np
 import torch
 from packaging.version import Version
 from torch.autograd import Function
@@ -325,11 +326,89 @@ def __init__(
         else:
             self.g_idx = None
 
+    def pack_tensor_with_numpy(self, raw_tensor):
+        raw_array = raw_tensor.cpu().numpy()
+        target_len = np.ceil(raw_array.shape[1] / self.n_pack).astype(int)
+        target_dtype = torch.tensor(0, dtype=self.compression_dtype).numpy().dtype
+        packed_array = np.zeros((raw_array.shape[0], target_len), dtype=target_dtype)
+        mask = np.uint8(2**self.bits - 1)
+        for j in range(packed_array.shape[1]):
+            start = self.n_pack * j
+            end = self.n_pack * (j + 1)
+            tmp = raw_array[:, start:end].astype(target_dtype)
+            tmp &= mask
+            for e in range(tmp.shape[1]):
+                tmp[:, e] = np.left_shift(tmp[:, e], self.bits * e)
+                packed_array[:, j] |= tmp[:, e]
+        packed_tensor = torch.from_numpy(packed_array).to(device=raw_tensor.device)
+        return packed_tensor
+
+    def unpack_tensor_with_numpy(self, packed_tensor):
+        packed_array = packed_tensor.cpu().numpy()
+        target_dtype = np.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else np.uint8
+        target_len = packed_array.shape[1] * self.n_pack
+        unpacked_array = np.zeros((packed_array.shape[0], target_len), dtype=target_dtype)
+        mask = np.uint8(2**self.bits - 1)
+        for j in range(packed_array.shape[1]):
+            for e in range(self.n_pack):
+                index = j * self.n_pack + e
+                tmp = packed_array[:, j]
+                tmp = np.left_shift(tmp, self.compress_bits - self.bits * (e + 1))
+                tmp = np.right_shift(tmp, self.compress_bits - self.bits)
+                if target_dtype == np.uint8:
+                    tmp &= mask
+                unpacked_array[:, index] = tmp.astype(target_dtype)
+        unpacked_tensor = torch.from_numpy(unpacked_array).to(device=packed_tensor.device)
+        return unpacked_tensor
+
+    def pack_tensor_with_torch(self, raw_tensor):
+        target_len = math.ceil(raw_tensor.shape[1] / self.n_pack)
+        packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        for j in range(packed_tensor.shape[1]):
+            start = self.n_pack * j
+            end = self.n_pack * (j + 1)
+            tmp = raw_tensor[:, start:end].type(self.compression_dtype)
+            tmp &= mask
+            for e in range(tmp.shape[1]):
+                tmp[:, e] = tmp[:, e] << (self.bits * e)
+                packed_tensor[:, j] |= tmp[:, e]
+        return packed_tensor
+
+    def unpack_tensor_with_torch(self, packed_tensor):
+        target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
+        target_len = packed_tensor.shape[1] * self.n_pack
+        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(self.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        for j in range(packed_tensor.shape[1]):
+            for e in range(self.n_pack):
+                index = j * self.n_pack + e
+                tmp = packed_tensor[:, j]
+                tmp = tmp << (self.compress_bits - self.bits * (e + 1))
+                tmp = tmp >> self.compress_bits - self.bits
+                if target_dtype == torch.uint8:
+                    tmp &= mask  # remove sign bit
+                unpacked_tensor[:, index].copy_(tmp.type(target_dtype))
+        logger.info(f"*****{unpacked_tensor}")
+        return unpacked_tensor
+
+    def pack_tensor(self, raw_tensor):
+        if "cuda" in self.device:
+            return self.pack_tensor_with_torch(raw_tensor)
+        else:
+            return self.pack_tensor_with_numpy(raw_tensor)
+
+    def unpack_tensor(self, packed_tensor):
+        if "cuda" in self.device:
+            return self.unpack_tensor_with_torch(packed_tensor)
+        else:
+            return self.unpack_tensor_with_numpy(packed_tensor)
+
     def pack(self, int_weight, scale, zp, bias, g_idx=None):
         if self.use_optimum_format:
-            self.scales = self.scales.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
-            self.qzeros = self.qzeros.t_().contiguous()
+            self.scales = self.scales.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
+            self.qzeros = self.qzeros.T.contiguous()
         int_weight = int_weight.to(self.device)
         if self.use_optimum_format and zp is None:
             # to avoid overflow
@@ -350,118 +429,73 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
         assert scale.shape == self.scales.shape, "Scale shape is mismatched."
         self.scales = scale.type(self.float_type).to(self.device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            int_weight = int_weight.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
+            int_weight = int_weight.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
         origin_shape = int_weight.shape
         target_shape = self.qweight.shape
         assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
 
         # pack weight
-        for j in range(target_shape[1]):
-            start = self.n_pack * j
-            end = self.n_pack * (j + 1)
-            tmp = int_weight[:, start:end].type(self.compression_dtype)
-            for e in range(tmp.shape[1]):
-                tmp[:, e] &= mask
-                tmp[:, e] = tmp[:, e] << (self.bits * e)
-                self.qweight[:, j] |= tmp[:, e]
+        self.qweight.copy_(self.pack_tensor(int_weight))
         if not self.use_optimum_format and self.compression_dim == 0:
-            self.qweight = self.qweight.t_().contiguous()
+            self.qweight = self.qweight.T.contiguous()
 
         if zp is not None:
             zp = zp.to(self.device)
             if self.use_optimum_format:
                 zp -= 1
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
-                self.qzeros = self.qzeros.t_().contiguous()
+                zp = zp.T.contiguous()
+                self.qzeros = self.qzeros.T.contiguous()
             assert hasattr(self, "qzeros"), "zp is not set when initializing."
-            target_shape = self.qzeros.shape
-            for j in range(target_shape[1]):
-                start = self.n_pack * j
-                end = self.n_pack * (j + 1)
-                tmp = zp[:, start:end].type(self.compression_dtype)
-                for e in range(tmp.shape[1]):
-                    tmp[:, e] &= mask
-                    tmp[:, e] = tmp[:, e] << (self.bits * e)
-                    self.qzeros[:, j] |= tmp[:, e]
+            self.qzeros.copy_(self.pack_tensor(zp))
             if self.use_optimum_format or self.compression_dim == 0:
-                self.qzeros = self.qzeros.t_().contiguous()
+                self.qzeros = self.qzeros.T.contiguous()
         if self.use_optimum_format:
-            self.scales = self.scales.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
-            self.qzeros = self.qzeros.t_().contiguous()
+            self.scales = self.scales.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
+            self.qzeros = self.qzeros.T.contiguous()
 
     def recover(self):
         logger.debug(f"Recovering {self} weight")
-        scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
-        qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight
+        scales = self.scales.T.contiguous() if self.use_optimum_format else self.scales
+        qweight = self.qweight.T.contiguous() if self.use_optimum_format else self.qweight
 
         device = scales.device
         fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
         if self.g_idx is None:
             # used for recovering fp32_weight
             self.g_idx = torch.tensor([i // self.groupsize for i in range(self.in_features)], dtype=torch.int32)
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(device)
-        if hasattr(self, "qzeros"):
-            weight_dtype = torch.uint8
-        else:
-            weight_dtype = torch.int8
         # unpack weight
-        weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.t_().contiguous()
-            qweight = qweight.t_().contiguous()
-        origin_shape = weight.shape
-        target_shape = qweight.shape
-        for j in range(target_shape[1]):
-            for e in range(self.n_pack):
-                index = j * self.n_pack + e
-                if index >= origin_shape[1]:
-                    continue
-                tmp = qweight[:, j]
-                tmp = tmp << (self.compress_bits - self.bits * (e + 1))
-                tmp = tmp >> self.compress_bits - self.bits
-                if weight_dtype == torch.uint8:
-                    tmp &= mask  # remove sign bit
-                weight[:, index] = tmp.type(weight_dtype)
+            qweight = qweight.T.contiguous()
+        weight = self.unpack_tensor(qweight)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.t_().contiguous()
+            weight = weight.T.contiguous()
+        weight = weight[: self.out_features, : self.in_features]  # avoid oversize
         if "int" not in self.dtype:
             new_weight = torch.zeros(self.out_features, self.in_features).to(device)
             for k, v in self.int2float_mapping.items():
                 new_weight += torch.where(weight == k, v, 0)
             weight = new_weight
         # unpack zero_point
         if hasattr(self, "qzeros"):
-            zp_dtype = self.compression_dtype  # to avoid overflow when weight-zp
-            zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
-            qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
+            qzeros = self.qzeros.T.contiguous() if self.use_optimum_format else self.qzeros
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
-                qzeros = qzeros.t_().contiguous()
-            origin_shape = zp.shape
-            target_shape = qzeros.shape
-            for j in range(target_shape[1]):
-                for e in range(self.n_pack):
-                    index = j * self.n_pack + e
-                    if index >= origin_shape[1]:
-                        continue
-                    tmp = qzeros[:, j]
-                    tmp = tmp << (self.compress_bits - self.bits * (e + 1))
-                    tmp = tmp >> self.compress_bits - self.bits
-                    tmp &= mask
-                    zp[:, index] = tmp.type(zp_dtype)
+                qzeros = qzeros.T.contiguous()
+            zp = self.unpack_tensor(qzeros)
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
+                zp = zp.T.contiguous()
+            zp = zp[: scales.shape[0], : scales.shape[1]]  # avoid oversize
             if self.use_optimum_format:
                 # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
                 zp += 1
                 zp = torch.where(zp > (2**self.bits - 1), 0, zp)
             # recover fp32 weight with int_weight, scale, and zero_point
             for idx in range(self.in_features):
-                fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * scales[:, self.g_idx[idx]]
+                fp32_weight[:, idx] = (torch.subtract(weight[:, idx], zp[:, self.g_idx[idx]]).to(torch.int8)) * scales[
+                    :, self.g_idx[idx]
+                ]
         else:
             # recover fp32 weight with int_weight, scale
             for idx in range(self.in_features):
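
For context, here is a minimal standalone sketch (not part of the commit) of the bit-packing scheme that the new pack_tensor_with_numpy/unpack_tensor_with_numpy helpers implement, assuming 4-bit values packed into an int32 compression dtype (bits=4, compress_bits=32, n_pack=8). The commit's pack_tensor/unpack_tensor dispatchers route non-CUDA tensors through this numpy path and CUDA tensors through the equivalent torch path.

import numpy as np

bits, compress_bits = 4, 32      # assumed settings: 4-bit values packed into int32
n_pack = compress_bits // bits   # 8 values per packed element
mask = np.uint8(2**bits - 1)

# toy weight row: sixteen unsigned 4-bit values
raw = np.arange(16, dtype=np.int32).reshape(1, -1)

# pack: OR each group of n_pack values into one int32, value e shifted left by bits * e
packed = np.zeros((raw.shape[0], raw.shape[1] // n_pack), dtype=np.int32)
for j in range(packed.shape[1]):
    tmp = raw[:, n_pack * j : n_pack * (j + 1)] & mask
    for e in range(n_pack):
        packed[:, j] |= tmp[:, e] << (bits * e)

# unpack: shift left to drop the higher fields, shift right to move the field back down,
# then mask off sign-extension bits since these values are unsigned
unpacked = np.zeros_like(raw, dtype=np.uint8)
for j in range(packed.shape[1]):
    for e in range(n_pack):
        tmp = packed[:, j] << (compress_bits - bits * (e + 1))
        tmp = tmp >> (compress_bits - bits)
        unpacked[:, j * n_pack + e] = tmp & mask

assert (unpacked == raw).all()  # round trip recovers the original values

Working over whole numpy rows at a time (rather than per-element torch slices on CPU) is what the commit relies on for the reported pack/unpack speedup.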
