Skip to content

Commit

Permalink
fix bug in smoothquant for auto alpha (#1287)
Browse files Browse the repository at this point in the history
(cherry picked from commit 496bd60)
  • Loading branch information
xin3he authored and chensuyue committed Sep 27, 2023
1 parent 424cf3a commit e9c14a5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 45 deletions.
58 changes: 27 additions & 31 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,9 @@ def _get_auto_loss(self, output, output_q, loss_type="abs", loss_alpha=1.0):
if len(output.shape) <= 2:
max_value = torch.max(torch.abs(output))
else:
max_value = torch.max(torch.abs(output.reshape(output.shape[0], -1)), dim=-1).values
output = output.reshape(output.shape[0], -1)
output_q = output_q.reshape(output_q.shape[0], -1)
max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1)
max_value = torch.clip(max_value, 1e-5)
output = output / max_value ##FIXME need copy not replace
output_q = output_q / max_value
Expand Down Expand Up @@ -712,7 +714,7 @@ def _update_scales_for_auto(self, absorb_scales, weight_scales):
weight_scale = self._reshape_scale_for_weight(layer, weight_scale)
layer.update_scale(input_scale, weight_scale) ##FIXME

def _get_one_sample_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes):
def _get_one_batch_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes):
self._change_qdq_for_auto(enable=False)

forward_wrapper(self.model, input, self.device) ##disable quant and get fp32 output
Expand Down Expand Up @@ -793,15 +795,15 @@ def dict_to_list(dic):
return best_alpha

def _auto_tune_alpha_new(
self, input_maxes, auto_calib_iter=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min"
self, input_maxes, calib_sample_num=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min"
):
"""Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly.
This function takes quantization of the former layers into consideration when applying qdq to one layer.
Also, it reduces the memory usage at the cost of increasing tuning time.
TODO may have compatibility issue when setting folding=True
:param input_maxes:
:param auto_calib_iter:
:param calib_sample_num:
:param alpha_min:
:param alpha_max:
:param alpha_step:
Expand All @@ -828,88 +830,82 @@ def _auto_tune_alpha_new(
self.absorb_to_layer, input_maxes, default_alpha, tuning=True
)
self._update_scales_for_auto(absorb_input_scales, weight_scales)
loss_alphas = {}
cnt = 0
multiply_factor = auto_calib_iter // 4 if auto_calib_iter >= 4 else auto_calib_iter
alpha_update_iter = 0
# multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha
multiply_factor = calib_sample_num // 4 if calib_sample_num >= 4 else calib_sample_num

best_alphas = default_alpha
if not self.dataloader:
self._qdq_model_unwrapper_for_auto()
return best_alphas
try:
for input, label in self.dataloader:
loss_alphas = {}
best_alphas_per_module = best_alphas
if isinstance(best_alphas, dict):
for key in self.absorb_to_layer.keys():
layer_names = self.absorb_to_layer[key]
for layer_name in layer_names:
best_alphas_per_module[layer_name] = best_alphas_per_module[key]

loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
if loss_alphas == {}:
loss_alphas = loss_tmp
else:
for key in loss_alphas.keys():
cur_loss = loss_alphas[key]
for alpha_key in cur_loss.keys():
cur_loss[alpha_key] += loss_tmp[key][alpha_key]
if isinstance(input, list):
input = move_input_to_device(input, self.device)
for inp in input:
cnt += inp.shape[0]
else:
cnt += input.shape[0]

if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
cnt += self.dataloader.batch_size
if cnt // multiply_factor >= 1:
alpha_update_iter += 1
cnt = 0
best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
for key in best_alphas.keys():
logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}")
absorb_input_scales, weight_scales = self._cal_scales(
self.absorb_to_layer, input_maxes, best_alphas, tuning=True
)
self._update_scales_for_auto(absorb_input_scales, weight_scales)
loss_alphas = {} ##TODO check need to remove this one
if cnt >= auto_calib_iter:
if cnt >= calib_sample_num:
break
except:
for input in self.dataloader:
loss_alphas = {}
best_alphas_per_module = best_alphas
if isinstance(best_alphas, dict):
for key in self.absorb_to_layer.keys():
layer_names = self.absorb_to_layer[key]
for layer_name in layer_names:
best_alphas_per_module[layer_name] = best_alphas_per_module[key]

loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
if loss_alphas == {}:
loss_alphas = loss_tmp
else:
for key in loss_alphas.keys():
cur_loss = loss_alphas[key]
for alpha_key in cur_loss.keys():
cur_loss[alpha_key] += loss_tmp[key][alpha_key]
if isinstance(input, list):
input = move_input_to_device(input, self.device)
for inp in input:
cnt += inp.shape[0]
else:
cnt += input.shape[0]
cnt += self.dataloader.batch_size
if cnt // multiply_factor >= 1:
alpha_update_iter += 1
cnt = 0

if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
for key in best_alphas.keys():
logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}")
absorb_input_scales, weight_scales = self._cal_scales(
self.absorb_to_layer, input_maxes, best_alphas, tuning=True
)
self._update_scales_for_auto(absorb_input_scales, weight_scales)
loss_alphas = {} ##TODO check need to remove this one
if cnt >= auto_calib_iter:
if cnt >= calib_sample_num:
break

best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
for key in best_alphas.keys():
logger.info(f"final {key}:{best_alphas[key]}")
logger.info(f"Final alpha {key}:{best_alphas[key]}")
self._qdq_model_unwrapper_for_auto()
logger.info("auto tuning done")
return best_alphas
Expand Down Expand Up @@ -999,7 +995,7 @@ def transform(

if alpha == "auto":
self.alpha_per_layer = self._auto_tune_alpha_new(
input_maxes_abs, auto_calib_iter=32, **auto_alpha_args
input_maxes_abs, calib_sample_num=32, **auto_alpha_args
) ##save the alpha

if alpha == "auto":
Expand Down
32 changes: 18 additions & 14 deletions test/algorithm/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,19 @@ def __iter__(self):

class LLMCalibDataloader:
def __init__(self):
self.batch_size = 1
self.batch_size = 3

def __iter__(self):
yield torch.ones([1, 3], dtype=torch.long)
for i in range(4):
yield torch.ones([3, 3], dtype=torch.long)


class TestSqDepthwiseConv(unittest.TestCase):
@classmethod
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3, 1, 1))
Expand Down Expand Up @@ -141,7 +142,7 @@ class TestSqConvOpFuseAuto(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3, 1, 1))
Expand Down Expand Up @@ -181,7 +182,7 @@ class TestSqConvOpFuse(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3, 1, 1))
Expand Down Expand Up @@ -386,21 +387,21 @@ class TestSqListInput(unittest.TestCase):
def setUpClass(self):
class ListDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield [torch.rand((1, 3))]

class TupleDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield (torch.rand((1, 3)))

class ListTupleDataLoader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
input1 = torch.rand((1, 3))
Expand Down Expand Up @@ -499,7 +500,7 @@ class TestAlphaAutoLinear(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3))
Expand Down Expand Up @@ -535,7 +536,7 @@ class TestSqLinearOpFuse(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 3))
Expand Down Expand Up @@ -736,6 +737,8 @@ def test_sq_qkv(self):
sq.transform(alpha=0.5, calib_iter=-1, folding=False)
assert isinstance(sq.model.model.decoder.layers[0].self_attn.k_proj, SQLinearWrapper)


class TestExample(unittest.TestCase):
def test_sq_quant(self):
from neural_compressor import PostTrainingQuantConfig, quantization

Expand Down Expand Up @@ -763,10 +766,11 @@ def forward(self, x):

class CalibDataloader:
def __init__(self):
self.batch_size = 1
self.batch_size = 3

def __iter__(self):
yield input_ids
for i in range(4):
yield input_ids

def calib_func(model):
for i in range(10):
Expand Down Expand Up @@ -935,7 +939,7 @@ class TestSqSkipOp(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 4))
Expand Down Expand Up @@ -992,7 +996,7 @@ class TestSqSkipOp_attn(unittest.TestCase):
def setUpClass(self):
class RandDataloader:
def __init__(self):
pass
self.batch_size = 1

def __iter__(self):
yield torch.rand((1, 4))
Expand Down

0 comments on commit e9c14a5

Please sign in to comment.