change threshold for ptq hpo (#1254)
ceci3 authored Jul 5, 2022
1 parent c590123 commit 8fe111b
Showing 2 changed files with 12 additions and 10 deletions.
4 changes: 2 additions & 2 deletions paddleslim/auto_compression/auto_strategy.py
@@ -77,8 +77,8 @@
 MAGIC_SPARSE_RATIO = 0.75
 ### TODO: 0.02 threshold maybe not suitable, need to check
 ### NOTE: reduce magic data to choose quantization aware training.
-MAGIC_MAX_EMD_DISTANCE = 0.0002 #0.02
-MAGIC_MIN_EMD_DISTANCE = 0.0001 #0.01
+MAGIC_MAX_EMD_DISTANCE = 0.00002 #0.02
+MAGIC_MIN_EMD_DISTANCE = 0.00001 #0.01

 DEFAULT_TRANSFORMER_STRATEGY = 'prune_0.25_int8'
 DEFAULT_STRATEGY = 'origin_int8'
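Judging by post_quant_hpo.py below, these constants bound the Earth Mover's Distance (EMD) measured between the float model's outputs and the quantized model's outputs; dividing them by ten makes the auto-compression pipeline fall back to the more expensive strategies sooner. The sketch below is hypothetical — choose_quant_strategy and the returned strategy names are illustrative, not PaddleSlim APIs — but it shows how thresholds of this kind can gate the strategy choice.

# Hypothetical illustration -- not the actual branching in auto_strategy.py.
MAGIC_MAX_EMD_DISTANCE = 0.00002
MAGIC_MIN_EMD_DISTANCE = 0.00001

def choose_quant_strategy(emd_distance: float) -> str:
    """Map the float-vs-int8 output gap to a compression strategy (illustrative names)."""
    if emd_distance <= MAGIC_MIN_EMD_DISTANCE:
        return "post_training_quantization"       # outputs barely move: cheap PTQ is enough
    if emd_distance <= MAGIC_MAX_EMD_DISTANCE:
        return "post_training_quantization_hpo"   # moderate gap: search PTQ hyper-parameters
    return "quantization_aware_training"          # large gap: retrain with QAT

Under this kind of gating, lowering both thresholds pushes more models into the HPO and QAT branches, which matches the intent in the commit title.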
18 changes: 10 additions & 8 deletions paddleslim/quant/post_quant_hpo.py
@@ -144,7 +144,12 @@ def standardization(data):
     """standardization numpy array"""
     mu = np.mean(data, axis=0)
     sigma = np.std(data, axis=0)
-    sigma = 1e-13 if sigma == 0. else sigma
+    if isinstance(sigma, list) or isinstance(sigma, np.ndarray):
+        for idx, sig in enumerate(sigma):
+            if sig == 0.:
+                sigma[idx] = 1e-13
+    else:
+        sigma = 1e-13 if sigma == 0. else sigma
     return (data - mu) / sigma
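The patched function now tolerates sigma being an array (np.std with axis=0 returns one value per column for 2-D input), replacing zero standard deviations with 1e-13 so the division never blows up. A minimal vectorized sketch of the same guard, using np.where instead of the explicit loop (illustrative only, not the repository code):

import numpy as np

def standardization(data):
    """Column-wise z-score; zero standard deviations are clamped to 1e-13."""
    data = np.asarray(data, dtype=np.float64)
    mu = np.mean(data, axis=0)
    sigma = np.std(data, axis=0)
    # np.where covers both the scalar and the per-column case in one step.
    sigma = np.where(sigma == 0., 1e-13, sigma)
    return (data - mu) / sigma

# A constant column now yields zeros instead of NaN/inf:
print(standardization([[1.0, 2.0], [1.0, 4.0], [1.0, 6.0]]))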


@@ -241,18 +246,15 @@ def eval_quant_model():
         if have_invalid_num(out_float) or have_invalid_num(out_quant):
             continue

-        try:
-            out_float = standardization(out_float)
-            out_quant = standardization(out_quant)
-        except:
-            continue
-        out_float_list.append(out_float)
-        out_quant_list.append(out_quant)
+        out_float_list.append(list(out_float))
+        out_quant_list.append(list(out_quant))
         valid_data_num += 1

         if valid_data_num >= max_eval_data_num:
             break

+    out_float_list = standardization(out_float_list)
+    out_quant_list = standardization(out_quant_list)
     emd_sum = cal_emd_lose(out_float_list, out_quant_list,
                            out_len_sum / float(valid_data_num))
     _logger.info("output diff: {}".format(emd_sum))
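This hunk changes the evaluation flow: instead of standardizing each sample's outputs on its own (and silently skipping samples where that raised), raw outputs are collected and standardized once, column-wise across all valid samples, before the EMD is computed. A self-contained sketch of that flow under stated assumptions — collect_and_compare is hypothetical, scipy.stats.wasserstein_distance stands in for the repository's cal_emd_lose, and have_invalid_num is approximated with np.isfinite:

import numpy as np
from scipy.stats import wasserstein_distance  # stand-in for cal_emd_lose

def standardization(data):
    """Column-wise z-score with the zero-sigma guard from the hunk above."""
    data = np.asarray(data, dtype=np.float64)
    mu, sigma = np.mean(data, axis=0), np.std(data, axis=0)
    return (data - mu) / np.where(sigma == 0., 1e-13, sigma)

def collect_and_compare(float_outputs, quant_outputs, max_eval_data_num=32):
    """Collect raw outputs first, standardize once, then measure the float-vs-quant gap."""
    out_float_list, out_quant_list, valid_data_num = [], [], 0
    for out_float, out_quant in zip(float_outputs, quant_outputs):
        if not (np.all(np.isfinite(out_float)) and np.all(np.isfinite(out_quant))):
            continue  # rough equivalent of have_invalid_num()
        out_float_list.append(list(out_float))
        out_quant_list.append(list(out_quant))
        valid_data_num += 1
        if valid_data_num >= max_eval_data_num:
            break
    out_float_arr = standardization(out_float_list)
    out_quant_arr = standardization(out_quant_list)
    # Average a 1-D EMD over samples as a rough proxy for cal_emd_lose().
    return float(np.mean([wasserstein_distance(f, q)
                          for f, q in zip(out_float_arr, out_quant_arr)]))

One plausible motivation for the change: standardizing across the whole batch keeps every sample on the same scale, so a single near-constant per-sample output no longer has to be dropped via try/except.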
