[AutoScheduler] Support early_stopping per task (#7377)
* [AutoScheduler] Support early_stopping per task

* address comment

* fix test

* Update python/tvm/auto_scheduler/task_scheduler.py

* Update python/tvm/auto_scheduler/task_scheduler.py

* trigger ci

* trigger ci
comaniac authored Feb 5, 2021
1 parent 38c9eb1 commit d8313d0
Showing 1 changed file with 36 additions and 11 deletions.
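
For context, a minimal usage sketch of the new option follows. Only `per_task_early_stopping` comes from this commit; `tasks`, `task_weights`, and the log-file name are placeholder assumptions (e.g. obtained from `auto_scheduler.extract_tasks`), not part of the patch:

```python
from tvm import auto_scheduler

# Assumed inputs: tasks / task_weights extracted elsewhere, e.g. with
# auto_scheduler.extract_tasks(mod["main"], params, target).
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=2000,  # total measurement budget shared by all tasks
    measure_callbacks=[auto_scheduler.RecordToFile("tuning_records.json")],
)

scheduler = auto_scheduler.TaskScheduler(tasks, task_weights)
# Stop tuning an individual task once it has gone 64 measurement trials
# without improving its best latency; the other tasks keep tuning.
scheduler.tune(tune_option, per_task_early_stopping=64)
```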
python/tvm/auto_scheduler/task_scheduler.py (36 additions, 11 deletions)
@@ -246,6 +246,9 @@ def __init__(
# task_cts[i] saves how many times task i is tuned
self.task_cts = [0 for _ in range(len(self.tasks))]

# task_best_cts[i] saves the round in which task i found its best latency
self.task_best_cts = [0 for _ in range(len(self.tasks))]

# task_costs_history[i] saves the latency history of task i
self.task_costs_history = [[] for _ in range(len(self.tasks))]

@@ -281,13 +284,14 @@ def tune(
search_policy="default",
search_policy_params=None,
adapative_training=False,
per_task_early_stopping=None,
):
"""Tune a batch of tasks together.
Parameters
----------
tune_option: TuningOptions
The options of tuning
The tuning options applied to all tasks.
search_policy : Union[str, List[SearchPolicy]] = "default"
The list of search policies.
If it is str,
@@ -299,10 +303,17 @@ def tune(
adapative_training : bool = False
Option used by XGBModel to reduce the model training frequency when there're
too many logs.
per_task_early_stopping : Optional[int]
Stop tuning a task early if it shows no improvement after n measurement trials.
"""
# init members
self.tune_option = tune_option
early_stopping = 1e20 if tune_option.early_stopping < 0 else tune_option.early_stopping
self.early_stopping_all = (
1e20 if tune_option.early_stopping < 0 else tune_option.early_stopping
)
self.early_stopping_task = (
1e20 if per_task_early_stopping is None else per_task_early_stopping
)

self.measurer = ProgramMeasurer(
tune_option.builder,
@@ -417,13 +428,13 @@ def tune(
if self.cur_score < self.best_score:
self.best_score = self.cur_score
self.best_ct = self.ct
elif self.ct - self.best_ct >= early_stopping and all(
elif self.ct - self.best_ct >= self.early_stopping_all and all(
cost < 1e9 for cost in self.best_costs
):
if self.tune_option.verbose >= 1:
print(
"Stop early since no performance improvement in the last "
+ str(early_stopping)
+ str(self.early_stopping_all)
+ " measurement trials."
)
break
@@ -439,15 +450,22 @@ def _tune_task(self, task_idx):
self.num_measures_per_round, self.measurer
)

self.task_cts[task_idx] += 1

for res in measure_results:
cost = array_mean(res.costs)
if cost < self.best_costs[task_idx]:
self.task_best_cts[task_idx] = self.task_cts[task_idx]
self.best_costs[task_idx] = cost

if len(measure_inputs) == 0:
# Stop tuning this task in the rest of the process if its search space has been
# fully explored or it has no improvement for a long while.
no_change_trials = (
self.task_cts[task_idx] - self.task_best_cts[task_idx]
) * self.num_measures_per_round
if len(measure_inputs) == 0 or no_change_trials > self.early_stopping_task:
self.dead_tasks.add(task_idx)

self.task_cts[task_idx] += 1
self.task_costs_history[task_idx].append(self.best_costs[task_idx])

self.ct += len(measure_inputs)
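
A side note on the hunk above, separate from the patch itself: `task_cts` counts tuning rounds rather than individual trials, so the round gap is scaled by `num_measures_per_round` before being compared to the per-task threshold. A tiny standalone sketch with made-up numbers:

```python
# Hypothetical values, only to illustrate the dead-task condition above.
num_measures_per_round = 64
early_stopping_task = 256      # value passed as per_task_early_stopping

task_cts = 12                  # rounds this task has been tuned so far
task_best_cts = 7              # round in which its best latency was found

no_change_trials = (task_cts - task_best_cts) * num_measures_per_round
print(no_change_trials)                        # 320: 5 rounds * 64 trials per round
print(no_change_trials > early_stopping_task)  # True -> the task joins dead_tasks
```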
@@ -494,17 +512,24 @@ def _restore_status(self, log_file, num_measures_per_round):
if task_idx is None:
continue

self.task_cts[task_idx] += 1

if res.error_no == 0:
self.best_costs[task_idx] = min(self.best_costs[task_idx], array_mean(res.costs))
cost = array_mean(res.costs)
if cost < self.best_costs[task_idx]:
self.best_costs[task_idx] = cost
self.task_best_cts[task_idx] = self.task_cts[task_idx]

self.task_cts[task_idx] += 1
for idx in range(len(self.tasks)):
if self.task_cts[idx] - self.task_best_cts[idx] > self.early_stopping_task:
self.dead_tasks.add(idx)

for i in range(len(self.tasks)):
# The computation of task_cts is just an estimation.
# The estimation may not be accurate if the log file is changed externally or
# `num_measures_per_round` is different from the last tuning.
self.task_cts[i] = int(self.task_cts[i] / num_measures_per_round + 0.5)
self.task_costs_history[i].append(self.best_costs[i])
self.task_cts[idx] = int(self.task_cts[idx] / num_measures_per_round + 0.5)
self.task_best_cts[idx] = int(self.task_best_cts[idx] / num_measures_per_round + 0.5)
self.task_costs_history[idx].append(self.best_costs[idx])

self.cur_score = self._compute_score(self.best_costs)
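
One more illustrative note, again separate from the diff: when restoring from a log file, the counters above track individual records, so they are converted back to approximate round counts with round-half-up. For example, with assumed numbers:

```python
# Assumed numbers: 130 restored records for one task, 64 measurements per round.
num_measures_per_round = 64
restored_records = 130

approx_rounds = int(restored_records / num_measures_per_round + 0.5)
print(approx_rounds)  # 2 -- 130 / 64 is about 2.03, rounded to the nearest round
```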

