Fix WOQ Linear pack/unpack slow issue 2x (#1837)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Kaihui-intel authored Jun 5, 2024
1 parent 29fdecb commit daa1431
Showing 1 changed file with 107 additions and 73 deletions.
180 changes: 107 additions & 73 deletions neural_compressor/adaptor/torch_utils/model_wrapper.py
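For readers skimming the diff below: the 2x speedup comes from replacing the per-element Python loops inside pack() and recover() with dedicated pack_tensor()/unpack_tensor() helpers that dispatch to a NumPy path on CPU and a torch path on CUDA; the bit-packing format itself is unchanged. The following is a minimal standalone sketch of that packing scheme (hypothetical names, 4-bit values packed into int32 words, NumPy assumed), not code from this repository:

import numpy as np

bits, compress_bits = 4, 32
n_pack = compress_bits // bits  # 8 four-bit values per int32 word

def pack(raw):
    # raw: (rows, cols) array of values in [0, 2**bits)
    rows, cols = raw.shape
    packed = np.zeros((rows, int(np.ceil(cols / n_pack))), dtype=np.int32)
    mask = np.uint8(2**bits - 1)
    for j in range(packed.shape[1]):
        tmp = raw[:, n_pack * j : n_pack * (j + 1)].astype(np.int32) & mask
        for e in range(tmp.shape[1]):
            packed[:, j] |= tmp[:, e] << (bits * e)  # place value e in its bit field
    return packed

def unpack(packed):
    rows, cols = packed.shape
    out = np.zeros((rows, cols * n_pack), dtype=np.uint8)
    for j in range(cols):
        for e in range(n_pack):
            # shift the target field up to the top bits, then arithmetically back down
            tmp = np.left_shift(packed[:, j], compress_bits - bits * (e + 1))
            tmp = np.right_shift(tmp, compress_bits - bits)
            out[:, j * n_pack + e] = tmp & (2**bits - 1)  # drop sign-extension bits
    return out

raw = np.random.randint(0, 2**bits, size=(2, 16), dtype=np.uint8)
assert np.array_equal(unpack(pack(raw)), raw)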
@@ -19,6 +19,7 @@
# since the model classes inherit torch.nn.Module.
import math

import numpy as np
import torch
from packaging.version import Version
from torch.autograd import Function
@@ -325,11 +326,89 @@ def __init__(
else:
self.g_idx = None

def pack_tensor_with_numpy(self, raw_tensor):
raw_array = raw_tensor.cpu().numpy()
target_len = np.ceil(raw_array.shape[1] / self.n_pack).astype(int)
target_dtype = torch.tensor(0, dtype=self.compression_dtype).numpy().dtype
packed_array = np.zeros((raw_array.shape[0], target_len), dtype=target_dtype)
mask = np.uint8(2**self.bits - 1)
for j in range(packed_array.shape[1]):
start = self.n_pack * j
end = self.n_pack * (j + 1)
tmp = raw_array[:, start:end].astype(target_dtype)
tmp &= mask
for e in range(tmp.shape[1]):
tmp[:, e] = np.left_shift(tmp[:, e], self.bits * e)
packed_array[:, j] |= tmp[:, e]
packed_tensor = torch.from_numpy(packed_array).to(device=raw_tensor.device)
return packed_tensor

def unpack_tensor_with_numpy(self, packed_tensor):
packed_array = packed_tensor.cpu().numpy()
target_dtype = np.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else np.uint8
target_len = packed_array.shape[1] * self.n_pack
unpacked_array = np.zeros((packed_array.shape[0], target_len), dtype=target_dtype)
mask = np.uint8(2**self.bits - 1)
for j in range(packed_array.shape[1]):
for e in range(self.n_pack):
index = j * self.n_pack + e
tmp = packed_array[:, j]
tmp = np.left_shift(tmp, self.compress_bits - self.bits * (e + 1))
tmp = np.right_shift(tmp, self.compress_bits - self.bits)
if target_dtype == np.uint8:
tmp &= mask
unpacked_array[:, index] = tmp.astype(target_dtype)
unpacked_tensor = torch.from_numpy(unpacked_array).to(device=packed_tensor.device)
return unpacked_tensor

def pack_tensor_with_torch(self, raw_tensor):
target_len = math.ceil(raw_tensor.shape[1] / self.n_pack)
packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
for j in range(packed_tensor.shape[1]):
start = self.n_pack * j
end = self.n_pack * (j + 1)
tmp = raw_tensor[:, start:end].type(self.compression_dtype)
tmp &= mask
for e in range(tmp.shape[1]):
tmp[:, e] = tmp[:, e] << (self.bits * e)
packed_tensor[:, j] |= tmp[:, e]
return packed_tensor

def unpack_tensor_with_torch(self, packed_tensor):
target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
target_len = packed_tensor.shape[1] * self.n_pack
unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(self.device)
mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
for j in range(packed_tensor.shape[1]):
for e in range(self.n_pack):
index = j * self.n_pack + e
tmp = packed_tensor[:, j]
tmp = tmp << (self.compress_bits - self.bits * (e + 1))
tmp = tmp >> self.compress_bits - self.bits
if target_dtype == torch.uint8:
tmp &= mask # remove sign bit
unpacked_tensor[:, index].copy_(tmp.type(target_dtype))
logger.info(f"*****{unpacked_tensor}")
return unpacked_tensor

def pack_tensor(self, raw_tensor):
if "cuda" in self.device:
return self.pack_tensor_with_torch(raw_tensor)
else:
return self.pack_tensor_with_numpy(raw_tensor)

def unpack_tensor(self, packed_tensor):
if "cuda" in self.device:
return self.unpack_tensor_with_torch(packed_tensor)
else:
return self.unpack_tensor_with_numpy(packed_tensor)

def pack(self, int_weight, scale, zp, bias, g_idx=None):
if self.use_optimum_format:
self.scales = self.scales.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
self.scales = self.scales.T.contiguous()
self.qweight = self.qweight.T.contiguous()
self.qzeros = self.qzeros.T.contiguous()
int_weight = int_weight.to(self.device)
if self.use_optimum_format and zp is None:
# to avoid overflow
@@ -350,118 +429,73 @@ def pack(int_weight, scale, zp, bias, g_idx=None):
assert scale.shape == self.scales.shape, "Scale shape is mismatched."
self.scales = scale.type(self.float_type).to(self.device)
if not self.use_optimum_format and self.compression_dim == 0:
int_weight = int_weight.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
int_weight = int_weight.T.contiguous()
self.qweight = self.qweight.T.contiguous()
origin_shape = int_weight.shape
target_shape = self.qweight.shape
assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)

# pack weight
for j in range(target_shape[1]):
start = self.n_pack * j
end = self.n_pack * (j + 1)
tmp = int_weight[:, start:end].type(self.compression_dtype)
for e in range(tmp.shape[1]):
tmp[:, e] &= mask
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qweight[:, j] |= tmp[:, e]
self.qweight.copy_(self.pack_tensor(int_weight))
if not self.use_optimum_format and self.compression_dim == 0:
self.qweight = self.qweight.t_().contiguous()
self.qweight = self.qweight.T.contiguous()

if zp is not None:
zp = zp.to(self.device)
if self.use_optimum_format:
zp -= 1
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
zp = zp.T.contiguous()
self.qzeros = self.qzeros.T.contiguous()
assert hasattr(self, "qzeros"), "zp is not set when initializing."
target_shape = self.qzeros.shape
for j in range(target_shape[1]):
start = self.n_pack * j
end = self.n_pack * (j + 1)
tmp = zp[:, start:end].type(self.compression_dtype)
for e in range(tmp.shape[1]):
tmp[:, e] &= mask
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qzeros[:, j] |= tmp[:, e]
self.qzeros.copy_(self.pack_tensor(zp))
if self.use_optimum_format or self.compression_dim == 0:
self.qzeros = self.qzeros.t_().contiguous()
self.qzeros = self.qzeros.T.contiguous()
if self.use_optimum_format:
self.scales = self.scales.t_().contiguous()
self.qweight = self.qweight.t_().contiguous()
self.qzeros = self.qzeros.t_().contiguous()
self.scales = self.scales.T.contiguous()
self.qweight = self.qweight.T.contiguous()
self.qzeros = self.qzeros.T.contiguous()

def recover(self):
logger.debug(f"Recovering {self} weight")
scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight
scales = self.scales.T.contiguous() if self.use_optimum_format else self.scales
qweight = self.qweight.T.contiguous() if self.use_optimum_format else self.qweight

device = scales.device
fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
if self.g_idx is None:
# used for recovering fp32_weight
self.g_idx = torch.tensor([i // self.groupsize for i in range(self.in_features)], dtype=torch.int32)
mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(device)
if hasattr(self, "qzeros"):
weight_dtype = torch.uint8
else:
weight_dtype = torch.int8
# unpack weight
weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
if not self.use_optimum_format and self.compression_dim == 0:
weight = weight.t_().contiguous()
qweight = qweight.t_().contiguous()
origin_shape = weight.shape
target_shape = qweight.shape
for j in range(target_shape[1]):
for e in range(self.n_pack):
index = j * self.n_pack + e
if index >= origin_shape[1]:
continue
tmp = qweight[:, j]
tmp = tmp << (self.compress_bits - self.bits * (e + 1))
tmp = tmp >> self.compress_bits - self.bits
if weight_dtype == torch.uint8:
tmp &= mask # remove sign bit
weight[:, index] = tmp.type(weight_dtype)
qweight = qweight.T.contiguous()
weight = self.unpack_tensor(qweight)
if not self.use_optimum_format and self.compression_dim == 0:
weight = weight.t_().contiguous()
weight = weight.T.contiguous()
weight = weight[: self.out_features, : self.in_features] # avoid oversize
if "int" not in self.dtype:
new_weight = torch.zeros(self.out_features, self.in_features).to(device)
for k, v in self.int2float_mapping.items():
new_weight += torch.where(weight == k, v, 0)
weight = new_weight
# unpack zero_point
if hasattr(self, "qzeros"):
zp_dtype = self.compression_dtype # to avoid overflow when weight-zp
zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
qzeros = self.qzeros.T.contiguous() if self.use_optimum_format else self.qzeros
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.t_().contiguous()
qzeros = qzeros.t_().contiguous()
origin_shape = zp.shape
target_shape = qzeros.shape
for j in range(target_shape[1]):
for e in range(self.n_pack):
index = j * self.n_pack + e
if index >= origin_shape[1]:
continue
tmp = qzeros[:, j]
tmp = tmp << (self.compress_bits - self.bits * (e + 1))
tmp = tmp >> self.compress_bits - self.bits
tmp &= mask
zp[:, index] = tmp.type(zp_dtype)
qzeros = qzeros.T.contiguous()
zp = self.unpack_tensor(qzeros)
if self.use_optimum_format or self.compression_dim == 0:
zp = zp.t_().contiguous()
zp = zp.T.contiguous()
zp = zp[: scales.shape[0], : scales.shape[1]] # avoid oversize
if self.use_optimum_format:
# zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
zp += 1
zp = torch.where(zp > (2**self.bits - 1), 0, zp)
# recover fp32 weight with int_weight, scale, and zero_point
for idx in range(self.in_features):
fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * scales[:, self.g_idx[idx]]
fp32_weight[:, idx] = (torch.subtract(weight[:, idx], zp[:, self.g_idx[idx]]).to(torch.int8)) * scales[
:, self.g_idx[idx]
]
else:
# recover fp32 weight with int_weight, scale
for idx in range(self.in_features):

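One closing note on the dequantization at the end of the hunk above: recover() rebuilds each fp32 column as (int_weight - zero_point) * scale, with g_idx mapping every input channel to its quantization group. Below is a minimal sketch of that loop with made-up shapes (groupsize, g_idx and the 4-bit range follow the diff; the subtraction is widened to int32 here, which is equivalent to the diff's int8 cast for 4-bit values), again not code from the repository:

import torch

out_features, in_features, groupsize, bits = 4, 8, 4, 4
n_groups = in_features // groupsize

# made-up unpacked integer weight and per-group quantization parameters
weight = torch.randint(0, 2**bits, (out_features, in_features), dtype=torch.uint8)
scales = torch.rand(out_features, n_groups)
zp = torch.full((out_features, n_groups), 2 ** (bits - 1), dtype=torch.int32)

# g_idx maps each input channel to its quantization group, as in recover()
g_idx = torch.tensor([i // groupsize for i in range(in_features)], dtype=torch.int32)

fp32_weight = torch.zeros(out_features, in_features)
for idx in range(in_features):
    # (int_weight - zero_point) * scale, column by column
    fp32_weight[:, idx] = (weight[:, idx].to(torch.int32) - zp[:, g_idx[idx]]) * scales[:, g_idx[idx]]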