 import math
 import os
 import re
-from collections import OrderedDict, UserDict
+from collections import OrderedDict, UserDict, namedtuple
 from functools import partial
 
 import yaml
@@ -1800,7 +1800,7 @@ def smooth_quant(
         assert folding, "IPEX version >= 2.1 is required for SmoothQuant folding=False."
 
         if not hasattr(self, "sq") or force_re_smooth:
-            from neural_compressor.adaptor.torch_utils.waq import TorchSmoothQuant
+            from .torch_utils.smooth_quant import TorchSmoothQuant
 
             self.sq = TorchSmoothQuant(
                 model._model, dataloader=dataloader, example_inputs=self.example_inputs, q_func=self.q_func
@@ -1813,18 +1813,17 @@ def smooth_quant(
             kwargs["percentile"] = percentile
         if scales_per_op is not None:
             kwargs["scales_per_op"] = scales_per_op
-        auto_alpha_args["init_alpha"] = default_alpha
         model._model = self.sq.transform(
             alpha=alpha,
             folding=folding,
             calib_iter=calib_iter,
             weight_clip=weight_clip,
+            default_alpha=default_alpha,
             auto_alpha_args=auto_alpha_args,
             **kwargs,
         )
         if self.sq.record_max_info:
             model.sq_max_info = self.sq.max_value_info
-            model.sq_scale_info = self.sq.sq_scale_info
         return model
 
     def _apply_pre_optimization(self, model, tune_cfg, recover=False):
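The hunk above promotes `default_alpha` from a mutation of `auto_alpha_args["init_alpha"]` to an explicit keyword of `transform()`, and drops the separate `sq_scale_info` attribute (per-op scale info is recomputed later from `sq_max_info`; see the `qdq_quantize` hunk below). A minimal sketch of the reworked call; the literal values are placeholders, not defaults from the source:

```python
# Sketch of the updated transform() contract; values are illustrative only.
model._model = self.sq.transform(
    alpha="auto",          # a fixed float such as 0.5, or "auto" for per-layer tuning
    folding=False,         # folding=False needs IPEX >= 2.1, per the assert in the first hunk
    calib_iter=100,
    weight_clip=True,
    default_alpha=0.5,     # now a first-class kwarg, no longer auto_alpha_args["init_alpha"]
    auto_alpha_args={},    # remaining auto-tuning options
)
```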
@@ -1841,7 +1840,7 @@ def _apply_pre_optimization(self, model, tune_cfg, recover=False):
         q_model = model._model
         sq_max_info = model.sq_max_info
         if sq_max_info:
-            from neural_compressor.adaptor.torch_utils.waq import TorchSmoothQuant
+            from .torch_utils.smooth_quant import TorchSmoothQuant
 
             tsq = TorchSmoothQuant(q_model, None)
             alpha = tune_cfg["recipe_cfgs"]["smooth_quant_args"]["alpha"]
@@ -1877,9 +1876,8 @@ def qdq_quantize(self, model, tune_cfg):
             model: qdq quantized model.
         """
         q_model = model._model
-        from neural_compressor.adaptor.torch_utils.waq import get_module, set_module
-
         from .torch_utils.model_wrapper import QDQLinear, SQLinearWrapper
+        from .torch_utils.smooth_quant import get_module, set_module
 
         smoothquant_scale_info = {}
         fallback_op_name_list = []
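`SQLinearWrapper`, imported here, is what the later hunk instantiates to derive per-op scale info. As a mental model, the wrapper rescales the input and lets the weights absorb the inverse factor so the product is unchanged; a simplified stand-in (not the real class, whose quantization bookkeeping is richer):

```python
import torch

class ToySQLinearWrapper(torch.nn.Module):
    """Simplified stand-in for SQLinearWrapper: x @ W.T == (x * s) @ (W / s).T."""

    def __init__(self, linear: torch.nn.Linear, input_scale: torch.Tensor):
        super().__init__()
        self.linear, self.input_scale = linear, input_scale
        with torch.no_grad():
            # Fold the inverse scale into the weight columns (broadcast over rows).
            self.linear.weight.div_(input_scale)

    def forward(self, x):
        return self.linear(x * self.input_scale)
```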
@@ -3319,7 +3317,37 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
         inplace = True if self.performance_only else False
 
         # fetch SmoothQuant scale info from pre-optimized model
-        smoothquant_scale_info = model.sq_scale_info
+        sq_max_info = model.sq_max_info
+        if sq_max_info:
+            smoothquant_scale_info = {}
+            from .torch_utils.model_wrapper import SQLinearWrapper
+            from .torch_utils.smooth_quant import get_module
+
+            for _, info in sq_max_info.items():
+                alpha = info["alpha"]
+                absorbed_layer = info["absorbed_layer"]
+                input_minmax = info["input_minmax"]
+                # for PEFT models, lora_B weights are 0
+                weight_max = info["weight_max"]
+                if self.sq.weight_clip:
+                    weight_max = weight_max.clamp(min=1e-5)
+                abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
+                input_power = torch.pow(abs_input_max, alpha)
+                weight_power = torch.pow(weight_max, 1 - alpha)
+                scale = torch.clip(input_power / weight_power, min=1e-5)
+                for op_name in absorbed_layer:
+                    module = copy.deepcopy(get_module(q_model._model, op_name))
+                    new_module = SQLinearWrapper(module, 1.0 / scale, input_minmax, alpha)
+                    weight_scale = new_module._get_weight_scale()
+                    smoothquant_scale_info[op_name] = {
+                        "alpha": new_module.alpha,
+                        "input_scale_for_mul": new_module.input_scale,
+                        "input_scale_after_mul": new_module.scale,
+                        "input_zero_point_after_mul": new_module.zero_point,
+                        "input_dtype": new_module.dtype,
+                        "weight_scale_after_mul": weight_scale,
+                    }
+                    logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}")
 
         # Check save_qconf_summary part is a workaround for IPEX bug.
         # Sometimes the prepared model from get_op_capablitiy loss this attribute
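For reference, the new branch above computes the standard SmoothQuant migration scale, scale = max|x|^alpha / max|w|^(1 - alpha), clipped away from zero, then splits it between activations (multiply by 1/scale) and weights (multiply by scale). A self-contained sketch of just that arithmetic; the helper name `sq_scale` is ours, not from the source:

```python
import torch

def sq_scale(abs_input_max: torch.Tensor, weight_max: torch.Tensor, alpha: float,
             weight_clip: bool = True) -> torch.Tensor:
    """SmoothQuant scale: |x|_max**alpha / |w|_max**(1 - alpha), clipped away from zero."""
    if weight_clip:
        # Mirrors weight_max.clamp(min=1e-5) above; e.g. PEFT lora_B weights can be all zero.
        weight_max = weight_max.clamp(min=1e-5)
    return torch.clip(abs_input_max.pow(alpha) / weight_max.pow(1 - alpha), min=1e-5)

# alpha = 0.5 with |x|_max = 4.0 and |w|_max = 1.0 gives scale = 2.0: activations are
# multiplied by 1/scale = 0.5 and weights by scale = 2.0 before quantization.
print(sq_scale(torch.tensor([4.0]), torch.tensor([1.0]), 0.5))  # tensor([2.])
```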
@@ -4767,7 +4795,7 @@ def teq_quantize(self, model, tune_cfg, dataloader, calib_func):
 
         supported_layers = ["Linear"]
         if folding:  # pragma: no cover
-            from neural_compressor.adaptor.torch_utils.waq import GraphTrace
+            from .torch_utils.smooth_quant import GraphTrace
 
             tg = GraphTrace()
             absorb_to_layer, _ = tg.get_absorb_to_layer(model, self.example_inputs, supported_layers)
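The folding path uses `GraphTrace` to map each module that can absorb a scale (e.g. a preceding LayerNorm) onto the Linear layers it feeds. A hedged usage sketch of the call shown above; the module names in the comment are illustrative, not from a real trace:

```python
from .torch_utils.smooth_quant import GraphTrace  # same import as in the hunk above

tg = GraphTrace()
# Returns a (mapping, leftovers) pair; only the mapping is used here.
absorb_to_layer, _ = tg.get_absorb_to_layer(model, self.example_inputs, ["Linear"])
# e.g. {"decoder.layers.0.self_attn_layer_norm": ["decoder.layers.0.self_attn.q_proj", ...]}
for absorber, linears in absorb_to_layer.items():
    logger.debug(f"{absorber} absorbs scales for {linears}")
```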