add use_HF_format for export_compressed_model (#1379)
Signed-off-by: Xin He <xin3.he@intel.com>
xin3he authored Nov 17, 2023
1 parent 0a20016 commit 5179da1
Showing 7 changed files with 257 additions and 120 deletions.
20 changes: 15 additions & 5 deletions docs/source/quantization_weight_only.md
@@ -96,6 +96,14 @@ To support low memory inference, Neural Compressor implemented WeightOnlyLinear,
| compression_dtype | torch.int32 | Data type for compressed dtype, select from [torch.int8\|16\|32\|64] |
| compression_dim | 1 | 0 means output channel while 1 means input channel |
| scale_dtype | torch.float32 | Data type for scale and bias |
| use_hf_format | False | Whether to use the popular format used on the HuggingFace Hub |

**Note:** The HuggingFace format differs from the default format in several ways (see the illustrative sketch after this note):

> 1: Compression Dimension: weight = 1, zero = 0, and both are transposed.
> 2: Zero Point: zero_point -= 1 before compression; zero_point is always required, even for symmetric quantization.
> 3: Group Index: Use the same number for each group instead of recording the channel order.
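
Below is a minimal illustrative sketch (not library code) of these three conventions on toy tensors; the sizes, `n_pack`, `shift_bias`, and the permutation are hypothetical values chosen to mirror the packing logic in the diff below.

```python
import torch

bits, group_size = 4, 32
in_features, out_features = 64, 16
n_pack = 32 // bits  # how many low-bit values fit into one int32

# 1. Compression dimension: in the default layout qweight is packed along the
#    input channel (dim 1); the HuggingFace layout stores it transposed.
default_qweight_shape = (out_features, in_features // n_pack)
hf_qweight_shape = (in_features // n_pack, out_features)

# 2. Zero point: symmetric quantization still stores a zero point. The integer
#    weight is shifted by 2 ** (bits - 1), and zero_point -= 1 before packing.
int_weight = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int32)
shift_bias = 2 ** (bits - 1)
hf_int_weight = int_weight + shift_bias                  # now in [0, 2 ** bits)
hf_zp = torch.full((out_features, in_features // group_size), shift_bias) - 1

# 3. Group index: each input channel records the index of its group rather
#    than the channel order, e.g. for a GPTQ activation-order permutation.
perm = torch.randperm(in_features)                       # hypothetical channel order
hf_g_idx = (torch.argsort(perm) // group_size).to(torch.int32)

print(default_qweight_shape, hf_qweight_shape, hf_zp[0, 0].item(), hf_g_idx[:8])
```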

### **User Code Example**
```python
@@ -119,12 +127,14 @@ conf = PostTrainingQuantConfig(
)
q_model = quantization.fit(model, conf, eval_func=eval_func)
q_model.save("saved_results")
compressed_model = q_model.export_compressed_model(
compression_dtype=torch.int32,
compression_dim=1,
scale_dtype=torch.float16,
)
compressed_model = q_model.export_compressed_model()
torch.save(compressed_model.state_dict(), "compressed_model.pt")
# or
model = Model()
compressed_model = export_compressed_model(
model,
saved_dir="saved_results",
)
```
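
Since this commit adds `use_HF_format` to `export_compressed_model`, exporting in the HuggingFace format should reduce to passing the new flag; a hedged sketch continuing the example above (the keyword name follows the `use_hf_format` kwarg added in this diff):

```python
# Continues the example above; assumes export_compressed_model forwards the
# use_hf_format flag added in this commit down to WeightOnlyLinear.
compressed_model = q_model.export_compressed_model(use_hf_format=True)
torch.save(compressed_model.state_dict(), "compressed_model_hf.pt")
```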

The saved_results folder contains two files: `best_model.pt` and `qconfig.json`, and the generated q_model is a fake quantized model.
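
For a quick sanity check, the saved compressed state dict can be inspected with plain PyTorch; a minimal sketch assuming the `qweight`/`qzeros`/`scales` buffer names introduced by this commit:

```python
import torch

# Load and inspect the packed state dict saved above.
state_dict = torch.load("compressed_model.pt", map_location="cpu")
for name, tensor in state_dict.items():
    if name.endswith(("qweight", "qzeros", "scales")):
        print(name, tuple(tensor.shape), tensor.dtype)
```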
236 changes: 136 additions & 100 deletions neural_compressor/adaptor/torch_utils/model_wrapper.py
@@ -215,10 +215,12 @@ def __init__(
scale_dtype=torch.float32,
compression_dtype=torch.int32,
compression_dim=1,
gptq_perm=False,
g_idx=False,
device="cpu",
use_hf_format=False,
):
super().__init__()
self.use_hf_format = use_hf_format
self.dtype = dtype
if "int" not in self.dtype: # for nf4, fp4
from neural_compressor.adaptor.torch_utils.weight_only import FLOAT_MAPPING, INT_MAPPING
@@ -249,69 +251,105 @@ def __init__(
assert compression_dim in [0, 1], (
"Only support 0 or 1 as compression dimension, " + "0 is output channel, 1 is input channel."
)
self.register_buffer(
"scale",
torch.zeros(
(out_features, math.ceil(in_features / self.groupsize)),
dtype=self.float_type,
).to(device),
)
if compression_dim == 1:
if self.use_hf_format:
self.register_buffer(
"packed_weight",
"scales",
torch.zeros(
(out_features, math.ceil(in_features / self.n_pack)),
(math.ceil(in_features / self.groupsize), out_features),
dtype=self.float_type,
).to(device),
)
self.scales = self.scales.T
self.register_buffer(
"qweight",
torch.zeros(
(math.ceil(in_features / self.n_pack), out_features),
dtype=self.compressed_dtype,
).to(device),
)
if zp:
self.register_buffer(
"packed_zp",
torch.zeros(
(self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)),
dtype=self.compressed_dtype,
).to(device),
)
else:
self.qweight = self.qweight.T
self.register_buffer(
"packed_weight",
"qzeros",
torch.zeros(
(math.ceil(out_features / self.n_pack), in_features),
(math.ceil(self.in_features / self.groupsize), math.ceil(self.out_features / self.n_pack)),
dtype=self.compressed_dtype,
).to(device),
)
if zp:
self.qzeros = self.qzeros.T
else:
self.register_buffer(
"scales",
torch.zeros(
(out_features, math.ceil(in_features / self.groupsize)),
dtype=self.float_type,
).to(device),
)
if compression_dim == 1:
self.register_buffer(
"packed_zp",
"qweight",
torch.zeros(
(math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.groupsize)),
(out_features, math.ceil(in_features / self.n_pack)),
dtype=self.compressed_dtype,
).to(device),
)
if zp:
self.register_buffer(
"qzeros",
torch.zeros(
(self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)),
dtype=self.compressed_dtype,
).to(device),
)
else:
self.register_buffer(
"qweight",
torch.zeros(
(math.ceil(out_features / self.n_pack), in_features),
dtype=self.compressed_dtype,
).to(device),
)
if zp:
self.register_buffer(
"qzeros",
torch.zeros(
(math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.groupsize)),
dtype=self.compressed_dtype,
).to(device),
)
if g_idx:
self.register_buffer("g_idx", torch.zeros(in_features, dtype=torch.int32).to(device))
else:
self.g_idx = None
if bias:
self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device))
else:
self.bias = None
if gptq_perm:
self.register_buffer("gptq_perm", torch.zeros(in_features, dtype=torch.int32).to(device))
else:
self.gptq_perm = None

def pack(self, int_weight, scale, zp, bias, gptq_perm=None):
def pack(self, int_weight, scale, zp, bias, g_idx=None):
int_weight = int_weight.to(self.device)
if self.use_hf_format and zp is None:
# to avoid overflow
int_weight = int_weight.type(torch.int32)
shift_bias = 2 ** (self.bits - 1)
int_weight += shift_bias
zp = torch.zeros_like(scale, dtype=torch.uint8) + shift_bias
if bias is not None:
assert hasattr(self, "bias"), "bias is not set when initializing."
self.bias = bias.type(self.float_type).to(self.device)
if gptq_perm is not None:
assert hasattr(self, "gptq_perm"), "gptq_perm is not set when initializing."
self.gptq_perm = gptq_perm.type(torch.int32).to(self.device)
assert scale.shape == self.scale.shape, "Scale shape is mismatched."
self.scale = scale.type(self.float_type).to(self.device)
if self.compression_dim == 0:
if g_idx is not None:
assert hasattr(self, "g_idx"), "g_idx is not set when initializing."
self.g_idx = g_idx.type(torch.int32).to(self.device)
if self.use_hf_format:
invperm = torch.argsort(self.g_idx)
self.g_idx = invperm // self.groupsize
self.g_idx = self.g_idx.type(torch.int32).to(self.device)
assert scale.shape == self.scales.shape, "Scale shape is mismatched."
self.scales = scale.type(self.float_type).to(self.device)
if not self.use_hf_format and self.compression_dim == 0:
int_weight = int_weight.T
self.packed_weight = self.packed_weight.T
self.qweight = self.qweight.T
origin_shape = int_weight.shape
target_shape = self.packed_weight.shape
target_shape = self.qweight.shape
assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype).to(self.device)

@@ -323,121 +361,112 @@ def pack(self, int_weight, scale, zp, bias, gptq_perm=None):
for e in range(tmp.shape[1]):
tmp[:, e] &= mask
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.packed_weight[:, j] |= tmp[:, e]
if self.compression_dim == 0:
self.packed_weight = self.packed_weight.T
self.qweight[:, j] |= tmp[:, e]
if not self.use_hf_format and self.compression_dim == 0:
self.qweight = self.qweight.T

if zp is not None:
zp = zp.to(self.device)
if self.compression_dim == 0:
if self.use_hf_format:
zp -= 1
if self.use_hf_format or self.compression_dim == 0:
zp = zp.T
self.packed_zp = self.packed_zp.T
assert hasattr(self, "packed_zp"), "zp is not set when initializing."
target_shape = self.packed_zp.shape
self.qzeros = self.qzeros.T
assert hasattr(self, "qzeros"), "zp is not set when initializing."
target_shape = self.qzeros.shape
for j in range(target_shape[1]):
start = self.n_pack * j
end = self.n_pack * (j + 1)
tmp = zp[:, start:end].type(self.compressed_dtype)
for e in range(tmp.shape[1]):
tmp[:, e] &= mask
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.packed_zp[:, j] |= tmp[:, e]
if self.compression_dim == 0:
self.packed_zp = self.packed_zp.T
self.qzeros[:, j] |= tmp[:, e]
if self.use_hf_format or self.compression_dim == 0:
self.qzeros = self.qzeros.T
if self.use_hf_format:
self.scales = self.scales.T
self.qweight = self.qweight.T
self.g_idx = self.g_idx
self.qzeros = self.qzeros.T

def recover(self):
logger.debug(f"Recovering {self} weight")
device = self.scale.device
if self.use_hf_format:
# Prevent broken id links of self.scales, self.qweight, self.g_idx, and self.qzeros
self.scales = self.scales.T
self.qweight = self.qweight.T
self.g_idx = self.g_idx
self.qzeros = self.qzeros.T
device = self.scales.device
fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
if self.g_idx is None:
# used for recovering fp32_weight
self.g_idx = torch.tensor([i // self.groupsize for i in range(self.in_features)], dtype=torch.int32)
mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype).to(device)
if hasattr(self, "packed_zp"):
if hasattr(self, "qzeros"):
weight_dtype = torch.uint8
else:
weight_dtype = torch.int8
# unpack weight
weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
packed_weight = self.packed_weight
if self.compression_dim == 0:
qweight = self.qweight
if not self.use_hf_format and self.compression_dim == 0:
weight = weight.T
packed_weight = packed_weight.T
qweight = qweight.T
origin_shape = weight.shape
target_shape = packed_weight.shape
target_shape = qweight.shape
for j in range(target_shape[1]):
for e in range(self.n_pack):
index = j * self.n_pack + e
if index >= origin_shape[1]:
continue
tmp = packed_weight[:, j]
tmp = qweight[:, j]
tmp = tmp << (self.compress_bits - self.bits * (e + 1))
tmp = tmp >> self.compress_bits - self.bits
if weight_dtype == torch.uint8:
tmp &= mask # remove sign bit
weight[:, index] = tmp.type(weight_dtype)
if self.compression_dim == 0:
if not self.use_hf_format and self.compression_dim == 0:
weight = weight.T
if "int" not in self.dtype:
new_weight = torch.zeros(self.out_features, self.in_features).to(device)
for k, v in self.int2float_mapping.items():
new_weight += torch.where(weight == k, v, 0)
weight = new_weight
# unpack zero_point
if hasattr(self, "packed_zp"):
if hasattr(self, "qzeros"):
zp_dtype = self.compressed_dtype # to avoid overflow when weight-zp
zp = torch.zeros(self.scale.shape, dtype=zp_dtype).to(device)
packed_zp = self.packed_zp
if self.compression_dim == 0:
zp = torch.zeros(self.scales.shape, dtype=zp_dtype).to(device)
qzeros = self.qzeros
if self.use_hf_format or self.compression_dim == 0:
zp = zp.T
packed_zp = packed_zp.T
qzeros = qzeros.T
origin_shape = zp.shape
target_shape = packed_zp.shape
target_shape = qzeros.shape
for j in range(target_shape[1]):
for e in range(self.n_pack):
index = j * self.n_pack + e
if index >= origin_shape[1]:
continue
tmp = packed_zp[:, j]
tmp = qzeros[:, j]
tmp = tmp << (self.compress_bits - self.bits * (e + 1))
tmp = tmp >> self.compress_bits - self.bits
tmp &= mask
zp[:, index] = tmp.type(zp_dtype)
if self.compression_dim == 0:
if self.use_hf_format or self.compression_dim == 0:
zp = zp.T
if self.use_hf_format:
# zp -= 1 may cause zp == -1; after recovery it becomes 2**self.bits - 1
zp += 1
zp = torch.where(zp > (2**self.bits - 1), 0, zp)
# recover fp32 weight with int_weight, scale, and zero_point
left_element = self.in_features % self.groupsize
if left_element != 0:
split_index = self.in_features // self.groupsize * self.groupsize
weight1 = weight[:, :-split_index].reshape(-1, self.groupsize)
scale1 = self.scale[:, :-1].reshape(-1, 1)
zp1 = zp[:, :-1].reshape(-1, 1)
weight1 = ((weight1 - zp1) * scale1).reshape(self.out_features, -1)
weight2 = weight[:, -split_index:]
scale2 = self.scale[:, -1:]
zp2 = zp[:, -1].reshape(-1, 1)
weight2 = (weight2 - zp2) * scale2
fp32_weight = torch.cat((weight1, weight2), dim=1)
else:
weight = weight.reshape(-1, self.groupsize)
scale = self.scale.reshape(-1, 1)
zp = zp.reshape(-1, 1)
fp32_weight = ((weight - zp) * scale).reshape(self.out_features, -1)
for idx in range(self.in_features):
fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * self.scales[:, self.g_idx[idx]]
else:
# recover fp32 weight with int_weight, scale
left_element = self.in_features % self.groupsize
if left_element != 0:
split_index = self.in_features // self.groupsize * self.groupsize
weight1 = weight[:, :split_index].reshape(-1, self.groupsize)
scale1 = self.scale[:, :-1].reshape(-1, 1)
weight1 = (weight1 * scale1).reshape(self.out_features, -1)
weight2 = weight[:, split_index:]
scale2 = self.scale[:, -1:]
weight2 = weight2 * scale2
fp32_weight = torch.cat((weight1, weight2), dim=1)
else:
weight = weight.reshape(-1, self.groupsize)
scale = self.scale.reshape(-1, 1)
fp32_weight = (weight * scale).reshape(self.out_features, -1)
if self.gptq_perm is not None:
invperm = torch.argsort(self.gptq_perm)
fp32_weight = fp32_weight[:, invperm]
for idx in range(self.in_features):
fp32_weight[:, idx] = weight[:, idx] * self.scales[:, self.g_idx[idx]]
return fp32_weight

def forward(self, input):
@@ -453,9 +482,16 @@ def forward(self, input):
return F.linear(input, weight, self.bias)

def extra_repr(self) -> str:
return "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format(
self.in_features, self.out_features, self.bits, self.groupsize, self.bias is not None
tmp_str = "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format(
self.in_features,
self.out_features,
self.bits,
self.groupsize,
self.bias is not None,
)
if self.use_hf_format:
tmp_str += ", use_hf_format=True"
return tmp_str


class FakeAffineTensorQuantFunction(Function):
2 changes: 2 additions & 0 deletions neural_compressor/adaptor/torch_utils/weight_only.py
@@ -396,6 +396,7 @@ def rtn_quantize(
compression_dim = kwargs.get("compression_dim", 1)
scale_dtype = kwargs.get("scale_dtype", torch.float32)
device = kwargs.get("device", "cpu")
use_hf_format = kwargs.get("use_hf_format", False)
for name, m in model.named_modules():
if m.__class__.__name__ not in supported_layers:
continue
@@ -451,6 +452,7 @@
compression_dim=compression_dim,
scale_dtype=scale_dtype,
device=device,
use_hf_format=use_hf_format,
)
new_module.pack(int_weight, scale, zp, m.bias)
if name == "":