
Commit 9600e1d

Narrow the tuning space of sq auto-tune (#1489)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent: 6c78dfe

4 files changed: +24 -16 lines changed
docs/source/tuning_strategies.md

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@ flowchart TD
 
 > `*` INC will detect the block pattern for [transformer-like](https://arxiv.org/abs/1706.03762) model by default.
 
-> For [smooth quantization](./smooth_quant.md), users can tune the smooth quantization alpha by providing a list of scalars for the `alpha` item. The tuning process will take place at the **start stage** of the tuning procedure. For details usage, please refer to the [smooth quantization example](./smooth_quant.md#Example).
+> For [smooth quantization](./smooth_quant.md), users can tune the smooth quantization alpha by providing a list of scalars for the `alpha` item. For details usage, please refer to the [smooth quantization example](./smooth_quant.md#Usage).
 
 > For [weight-only quantization](./quantization_weight_only.md), users can tune the weight-only algorithms from the available [pre-defined configurations](./quantization_weight_only.md#woq-algorithms-tuning). The tuning process will take place at the **start stage** of the tuning procedure, preceding the smooth quantization alpha tuning. For details usage, please refer to the [weight-only quantization example](./quantization_weight_only.md#woq-algorithms-tuning).
 *Please note that this behavior is specific to the `ONNX Runtime` backend.*

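For reference, a minimal sketch of the usage this doc change describes, following the recipe pattern the tests below exercise; `model`, `calib_dataloader`, and `eval_fn` are hypothetical placeholders for user-provided objects:

    import numpy as np
    from neural_compressor import quantization
    from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion

    # Passing a list of scalars for `alpha` asks the strategy to tune it;
    # a single float would fix alpha instead of tuning it.
    conf = PostTrainingQuantConfig(
        tuning_criterion=TuningCriterion(max_trials=8),
        recipes={
            "smooth_quant": True,
            "smooth_quant_args": {"alpha": np.arange(0.1, 0.2, 0.05).tolist()},
        },
    )
    # model, calib_dataloader, and eval_fn stand in for user objects.
    q_model = quantization.fit(model, conf, calib_dataloader=calib_dataloader, eval_func=eval_fn)
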
neural_compressor/strategy/strategy.py

Lines changed: 18 additions & 0 deletions
@@ -186,6 +186,7 @@ def __init__(
         # track tuning cfg with the current best accuracy
         self.cur_best_tuning_cfg = {}
         self.re_quant = False
+        self.early_stop_sq_tuning_process = False
 
         self._trials_count = 0
         self._capability = None
@@ -1152,6 +1153,9 @@ def _should_tuning_sq_alpha(self, recipes):
     def tuning_sq_alpha(self, tuning_space, tuning_cfg, recipes):
         """Tuning smooth quant's alpha.
 
+        After trying all alpha values, the sq tuning process will stop early, returning the current best qmodel,
+        even if the current best accuracy does not meet the accuracy criterion.
+
         Args:
             tuning_space: tuning space
             tuning_cfg: the initial tuning config
@@ -1166,8 +1170,12 @@ def tuning_sq_alpha(self, tuning_space, tuning_cfg, recipes):
         ), "Only tune the smooth quant's alpha when user provide the alpha list,\
             but got alpha_list: {alpha_list}"
         logger.info("[STRATEGY] Start tuning smooth quant'alpha.")
+        number_of_alpha = len(sq_alpha_list)
+        sq_trials_cnt = 0
         sq_sampler = tuning_sampler_dict.get_class("smooth_quant")(tuning_space, [], tuning_cfg, sq_alpha_list)
         for tune_cfg in sq_sampler:
+            sq_trials_cnt += 1
+            self.early_stop_sq_tuning_process = sq_trials_cnt == number_of_alpha
             yield tune_cfg
 
     def _should_tuning_woq_algo(self):
@@ -1961,6 +1969,16 @@ def stop(self, timeout, trials_count):
             need_stop = True
         else:
             need_stop = False
+        if not need_stop and self.early_stop_sq_tuning_process:
+            if self.best_tuning_cfg is None:
+                self.best_tuning_cfg = self._tune_cfg_converter(self.cur_best_tuning_cfg)
+            logger.info(
+                "[Strategy] Tried all alpha values but none met the accuracy criterion. "
+                "The tuning process was early stopped and "
+                f"the currently best model(accuracy: {self.cur_best_acc}) was returned."
+            )
+
+            need_stop = True
 
         return need_stop

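The net effect: tuning_sq_alpha flags its final trial, and stop() then ends the whole tuning process once every alpha has been tried, falling back to the best model found so far. A stripped-down sketch of that interplay (simplified, hypothetical names; not the actual TuneStrategy class):

    class SqAlphaTuner:
        """Toy stand-in for the strategy's sq-alpha early-stop logic."""

        def __init__(self, alpha_list):
            self.alpha_list = alpha_list
            self.early_stop_sq_tuning_process = False
            self.cur_best_cfg = None

        def tuning_sq_alpha(self):
            # Flag the last alpha so stop() can end tuning right after it.
            for idx, alpha in enumerate(self.alpha_list, start=1):
                self.early_stop_sq_tuning_process = idx == len(self.alpha_list)
                yield {"alpha": alpha}

        def stop(self, met_accuracy_criterion):
            if met_accuracy_criterion:
                return True
            if self.early_stop_sq_tuning_process:
                # All alphas tried, none met the criterion: keep best-so-far.
                print(f"early stop, best-so-far cfg: {self.cur_best_cfg}")
                return True
            return False

    tuner = SqAlphaTuner([0.1, 0.15, 0.2])
    for cfg in tuner.tuning_sq_alpha():
        tuner.cur_best_cfg = tuner.cur_best_cfg or cfg  # pretend the 1st cfg stays best
        if tuner.stop(met_accuracy_criterion=False):
            break  # stops after the 3rd (last) alpha, no endless exploration
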
test/algorithm/test_smooth_quant.py

Lines changed: 5 additions & 9 deletions
@@ -1150,6 +1150,8 @@ def _test_sq_tune_alpha_common(self, eval_func, alpha=np.arange(0.1, 0.2, 0.05).
         from neural_compressor import quantization
         from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion
 
+        logger.info(f"alpha is: {alpha}")
+
         tuning_criterion = TuningCriterion(max_trials=8)
 
         fp32_model = DemoModel()
@@ -1183,8 +1185,8 @@ def fake_eval(model, eval_result_lst):
         # test for alpha is a list
         for eval_result_lst, note in [
             ([1, 0.8, 1.1, 0.7, 1.1], "Expect tuning ends at 2nd trial with alpha is 0.15"),
-            ([1, 0.8, 0.9, 0.7, 1.1], "Expect tuning ends at 4th trial with alpha is 0.15"),
-            ([1, 0.9, 0.8, 0.7, 1.1], "Expect tuning ends at 4th trial with alpha is 0.10"),
+            ([1, 0.8, 0.9, 0.7, 1.1], "Expect tuning ends at 2nd trial with alpha is 0.15"),
+            ([1, 0.9, 0.8, 0.7, 1.1], "Expect tuning ends at 1st trial with alpha is 0.10"),
         ]:
             logger.info(f"test_sq_tune_alpha_common with eval_result_lst: {eval_result_lst}")
             logger.info(note)
@@ -1222,13 +1224,7 @@ def fake_eval(model, eval_result_lst):
                 [1, 0.8, 0.9, 0.7, 1.1],
                 np.arange(0.1, 0.2, 0.05).tolist(),
                 "auto",
-                "Expect tuning ends at 4th trial with alpha is 0.15 at basic strategy.",
-            ),
-            (
-                [1, 1.1, 0.8, 0.7, 1.1],
-                np.arange(0.1, 0.2, 0.05).tolist(),
-                0,
-                "Expect tuning ends at 1th trial with alpha is 0.1",
+                "Expect tuning ends at 2th trial with alpha is 0.15 at basic strategy.",
             ),
         ]:
             logger.info("test_sq_tune_alpha_common with ")

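The tests drive tuning with a scripted evaluation function: the first value in eval_result_lst plays the fp32 baseline, and later values play the successive trials. The hunks above don't show fake_eval's body, so this is a hypothetical reconstruction of the pattern:

    def make_fake_eval(eval_result_lst):
        # Replay scripted accuracies in call order (hypothetical helper).
        results = iter(eval_result_lst)

        def fake_eval(model):
            return next(results)

        return fake_eval

    # alpha = np.arange(0.1, 0.2, 0.05).tolist() == [0.1, 0.15], so only two
    # sq trials run (0.8, then 0.9); neither reaches the baseline of 1, and
    # with this commit tuning stops at the 2nd trial with the best-so-far
    # model (0.9 at alpha 0.15) instead of running further trials.
    eval_fn = make_fake_eval([1, 0.8, 0.9, 0.7, 1.1])
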
test/algorithm/test_smooth_quant_onnx.py

Lines changed: 0 additions & 6 deletions
@@ -279,12 +279,6 @@ def fake_eval(model, eval_result_lst):
                 "auto",
                 "Expect tuning ends at 4th trial with alpha is 0.15 at basic strategy.",
             ),
-            (
-                [1, 1.1, 0.8, 0.7, 1.1],
-                np.arange(0.1, 0.2, 0.05).tolist(),
-                0,
-                "Expect tuning ends at 1th trial with alpha is 0.1",
-            ),
         ]:
             logger.info("test_sq_tune_alpha_common with ")
             logger.info(f"eval_result_lst: {eval_result_lst}, alpha: {alpha}, quant_level: {quant_level}")
