Quantconfig add mse field (#825)
* Quantconfig add mse field

* code review

* ruff format

* code opt

* set Quantizer.configure mse=0.0

* code review
CL-ModelCloud authored Dec 12, 2024
1 parent 6e46450 commit c6bbb02
Showing 5 changed files with 18 additions and 8 deletions.
2 changes: 1 addition & 1 deletion gptqmodel/models/auto.py
@@ -9,10 +9,10 @@
 from huggingface_hub import list_repo_files
 from transformers import AutoConfig

-from ._const import get_best_device
 from ..utils import BACKEND, EVAL
 from ..utils.logger import setup_logger
 from ..utils.model import check_and_get_model_type
+from ._const import get_best_device
 from .base import BaseGPTQModel, QuantizeConfig
 from .definitions.baichuan import BaiChuanGPTQ
 from .definitions.bloom import BloomGPTQ
3 changes: 2 additions & 1 deletion gptqmodel/models/base.py
@@ -510,6 +510,7 @@ def store_input_hook(_, args, kwargs):
             for name in subset:
                 bits = self.quantize_config.bits
                 sym = self.quantize_config.sym
+                mse = self.quantize_config.mse
                 if self.quantize_config.dynamic is not None:
                     layer_name = f"{self.layers_node}.{i}.{name}"
@@ -526,7 +527,7 @@
                     bits,
                     perchannel=True,
                     sym=sym,
-                    mse=False,
+                    mse=mse,
                 )

             for name in skipped_modules:
9 changes: 8 additions & 1 deletion gptqmodel/models/writer.py
@@ -18,7 +18,7 @@
 from transformers.modeling_utils import no_init_weights
 from transformers.utils.generic import ContextManagers

-from ..quantization.config import (FORMAT, META_FIELD_DAMP_AUTO_INCREMENT, META_FIELD_DAMP_PERCENT,
+from ..quantization.config import (FORMAT, META_FIELD_DAMP_AUTO_INCREMENT, META_FIELD_DAMP_PERCENT, META_FIELD_MSE,
                                    META_FIELD_QUANTIZER, META_FIELD_STATIC_GROUPS, META_FIELD_TRUE_SEQUENTIAL,
                                    META_FIELD_URI, META_QUANTIZER_GPTQMODEL, META_VALUE_URI, MIN_VERSION_WITH_V2)
 from ..utils.backend import BACKEND
@@ -106,6 +106,13 @@ def save_quantized(
             key=META_FIELD_TRUE_SEQUENTIAL,
             value=self.quantize_config.true_sequential
         )
+
+        self.quantize_config.meta_set(
+            key=META_FIELD_MSE,
+            value=self.quantize_config.mse
+        )
+
+
         # The config, quantize_config and model may be edited in place in save_quantized.
         config = copy.deepcopy(self.model.config)
         quantize_config = copy.deepcopy(self.quantize_config)
4 changes: 4 additions & 0 deletions gptqmodel/quantization/config.py
@@ -37,6 +37,8 @@
 META_FIELD_STATIC_GROUPS = "static_groups"
 META_FIELD_TRUE_SEQUENTIAL = "true_sequential"

+META_FIELD_MSE = "mse"
+
 # pkg names
 PKG_AUTO_ROUND = "auto-round"
@@ -116,6 +118,8 @@ class QuantizeConfig():
     # if you inference with gptqmodel, save to gptq_v2 format for best result
     format: FORMAT = field(default=FORMAT.GPTQ)

+    mse: float = field(default=0.0)
+
     # parallel packing will make ~40% speedup for many models, but may cause OOM in some large models
     # if OOM, can set to False
     parallel_packing: bool = field(default=True)
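
With the new field in place, enabling the search is a one-line config change. A minimal usage sketch (the model id, bits, and group_size are placeholder values; GPTQModel.load is assumed to be the repo's loading entry point):

    # Hypothetical usage of the new QuantizeConfig.mse field.
    from gptqmodel import GPTQModel, QuantizeConfig

    quant_config = QuantizeConfig(
        bits=4,          # placeholder
        group_size=128,  # placeholder
        mse=2.4,         # > 0.0 turns on the MSE range search; also the error-norm exponent
    )
    model = GPTQModel.load("facebook/opt-125m", quant_config)  # placeholder model id

Leaving mse at its 0.0 default preserves the previous behavior, matching the old hard-coded mse=False.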
8 changes: 3 additions & 5 deletions gptqmodel/quantization/quantizer.py
@@ -28,8 +28,7 @@ def configure(
         bits,
         perchannel=False,
         sym=True,
-        mse=False,
-        norm=2.4,
+        mse=0.0,  # 2.4
         grid=100,
         maxshrink=0.8,
         trits=False,
@@ -38,7 +37,6 @@
         self.perchannel = perchannel
         self.sym = sym
         self.mse = mse
-        self.norm = norm
         self.grid = grid
         self.maxshrink = maxshrink
         if trits:
@@ -86,7 +84,7 @@ def find_params(self, x, weight=False):
         else:
             self.zero = torch.round(-xmin / self.scale)

-        if self.mse:
+        if self.mse > 0.0:
             best = torch.full([x.shape[0]], float("inf"), device=dev)
             for i in range(int(self.maxshrink * self.grid)):
                 p = 1 - i / self.grid
@@ -97,7 +95,7 @@
                 q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                 q -= x
                 q.abs_()
-                q.pow_(self.norm)
+                q.pow_(self.mse)
                 err = torch.sum(q, 1)
                 tmp = err < best
                 if torch.any(tmp):
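
Net effect of the quantizer changes: the separate norm parameter is folded into mse, which both gates the shrink-grid search (mse > 0.0) and supplies the error-norm exponent, with the inline # 2.4 comment preserving the old norm default as a suggested value. A self-contained sketch of that search (helper and variable names are illustrative; the real find_params also guards degenerate all-zero rows):

    import torch

    def mse_scale_search(x, maxq, grid=100, maxshrink=0.8, mse=2.4):
        # Per-row asymmetric min/max range, as in Quantizer.find_params.
        xmin = x.min(dim=1).values.clamp(max=0.0)
        xmax = x.max(dim=1).values.clamp(min=0.0)
        scale = (xmax - xmin) / maxq
        zero = torch.round(-xmin / scale)

        if mse > 0.0:  # the merged flag/exponent introduced by this commit
            best = torch.full([x.shape[0]], float("inf"), device=x.device)
            for i in range(int(maxshrink * grid)):
                p = 1 - i / grid  # progressively shrink the quantization range
                scale1 = p * (xmax - xmin) / maxq
                zero1 = torch.round(-p * xmin / scale1)
                # Quantize/dequantize round trip, then the p-norm reconstruction error.
                q = torch.clamp(torch.round(x / scale1.unsqueeze(1)) + zero1.unsqueeze(1), 0, maxq)
                q = (q - zero1.unsqueeze(1)) * scale1.unsqueeze(1)
                err = (q - x).abs_().pow_(mse).sum(dim=1)
                better = err < best
                if torch.any(better):
                    best[better] = err[better]
                    scale[better] = scale1[better]
                    zero[better] = zero1[better]
        return scale, zero

    # Example: per-row scales for 4-bit quantization (maxq = 2**4 - 1).
    scale, zero = mse_scale_search(torch.randn(8, 128), maxq=15)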
