Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reap completion_criterion method #3092

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 49 additions & 54 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def run_n_trials(
Args:
max_trials: Maximum number of trials to run.
ignore_global_stopping_strategy: If set, Scheduler will skip the global
stopping strategy in completion_criterion.
stopping strategy in should_consider_optimization_complete.
timeout_hours: Limit on length of the optimization; if reached, the
optimization will abort even if completion criterion is not yet reached.
idle_callback: Callable that takes a Scheduler instance as an argument to
Expand Down Expand Up @@ -595,9 +595,10 @@ def run_all_trials(
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
idle_callback: Callable[[Scheduler], Any] | None = None,
) -> OptimizationResult:
"""Run all trials until ``completion_criterion`` is reached (by default,
completion criterion is reaching the ``num_trials`` setting, passed to
scheduler on instantiation as part of ``SchedulerOptions``).
"""Run all trials until ``should_consider_optimization_complete`` yields
true (by default, ``should_consider_optimization_complete`` will yield true
when reaching the ``total_trials`` setting, passed to scheduler on
instantiation as part of ``SchedulerOptions``).

NOTE: This function is available only when ``SchedulerOptions.total_trials`` is
specified.
Expand Down Expand Up @@ -712,7 +713,7 @@ def run_trials_and_yield_results(
a completion signal is received from the generation strategy, or
``max_trials`` trials have been run (whichever happens first).
ignore_global_stopping_strategy: If set, Scheduler will skip the global
stopping strategy in completion_criterion.
stopping strategy in should_consider_optimization_complete.
timeout_hours: Maximum number of hours, for which
to run the optimization. This function will abort after running
for `timeout_hours` even if stopping criterion has not been reached.
Expand Down Expand Up @@ -960,51 +961,17 @@ def should_consider_optimization_complete(self) -> tuple[bool, str]:
run more trials (and conclude the optimization via ``_complete_optimization``).

NOTE: An optimization is considered complete when a generation strategy signaled
completion or when the ``completion_criterion`` on this scheduler
evaluates to ``True``. The ``completion_criterion`` method is also responsible
for checking global_stopping_strategy's decision as well. Alongside the stop
decision, this function returns a string describing the reason for stopping
the optimization.
completion or when this method's own checks evaluate to ``True``: the
``global_stopping_strategy`` (unless it is being ignored) suggests stopping, or
``total_trials`` trials have been run. This method is responsible for checking
the global_stopping_strategy's decision. Alongside the stop decision, this
function returns a string describing the reason for stopping the optimization.
"""
if self._optimization_complete:
return True, ""

should_complete, completion_message = self.completion_criterion()
if should_complete:
self.logger.info(f"Completing the optimization: {completion_message}.")
return should_complete, completion_message

def should_abort_optimization(self, timeout_hours: float | None = None) -> bool:
"""Checks whether this scheduler has reached some interruption / abort
criterion, such as an overall optimization timeout, tolerated failure rate, etc.
"""
# if failure rate is exceeded, raise an exception.
# this check should precede others to ensure it is not skipped.
self.error_if_failure_rate_exceeded()

# if optimization is timed out, return True, else return False
timed_out = (
timeout_hours is not None
and self._latest_optimization_start_timestamp is not None
and current_timestamp_in_millis()
- none_throws(self._latest_optimization_start_timestamp)
>= none_throws(timeout_hours) * 60 * 60 * 1000
)
if timed_out:
self.logger.error(
"Optimization timed out (timeout hours: " f"{timeout_hours})!"
)
return timed_out

def completion_criterion(self) -> tuple[bool, str]:
"""Optional stopping criterion for optimization, which checks whether
``total_trials`` trials have been run or the ``global_stopping_strategy``
suggests stopping the optimization.
should_stop, message = False, ""

Returns:
A boolean representing whether the optimization should be stopped,
and a string describing the reason for stopping.
"""
if (
not self.__ignore_global_stopping_strategy
and self.options.global_stopping_strategy is not None
Expand All @@ -1022,18 +989,50 @@ def completion_criterion(self) -> tuple[bool, str]:
experiment=self.experiment
)
if stop_optimization:
return True, global_stopping_msg
should_stop = True
message = global_stopping_msg

if self.options.total_trials is None:
elif self.options.total_trials is None:
# We validate that `total_trials` is set in `run_all_trials`,
# so it will not run indefinitely.
return False, ""

num_trials = len(self.trials)
should_stop = num_trials >= none_throws(self.options.total_trials)
message = "Exceeding the total number of trials." if should_stop else ""
else:
num_trials = len(self.trials)
should_stop = num_trials >= none_throws(self.options.total_trials)
message = "Exceeding the total number of trials." if should_stop else ""

if should_stop:
self.logger.info(
f"Completing the optimization: {message}. "
f"`should_consider_optimization_complete` "
f"is `True`, not running more trials."
)

return should_stop, message

def should_abort_optimization(self, timeout_hours: float | None = None) -> bool:
"""Checks whether this scheduler has reached some interruption / abort
criterion, such as an overall optimization timeout, tolerated failure rate, etc.
"""
# if failure rate is exceeded, raise an exception.
# this check should precede others to ensure it is not skipped.
self.error_if_failure_rate_exceeded()

# if optimization is timed out, return True, else return False
timed_out = (
timeout_hours is not None
and self._latest_optimization_start_timestamp is not None
and current_timestamp_in_millis()
- none_throws(self._latest_optimization_start_timestamp)
>= none_throws(timeout_hours) * 60 * 60 * 1000
)
if timed_out:
self.logger.error(
"Optimization timed out (timeout hours: " f"{timeout_hours})!"
)
return timed_out

def report_results(self, force_refit: bool = False) -> dict[str, Any]:
"""Optional user-defined function for reporting intermediate
and final optimization results (e.g. make some API call, write to some
Expand Down Expand Up @@ -1201,10 +1200,6 @@ def run(self, max_new_trials: int, timeout_hours: float | None = None) -> bool:
completion_message,
) = self.should_consider_optimization_complete()
if optimization_complete:
self.logger.info(
completion_message
+ "`completion_criterion` is `True`, not running more trials."
)
return False

if self.should_abort_optimization(timeout_hours=timeout_hours):
Expand Down
8 changes: 4 additions & 4 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2023,7 +2023,7 @@ def test_fetch_and_process_trials_data_results_failed_objective(self) -> None:
)
self.assertEqual(scheduler.experiment.trials[0].status, TrialStatus.FAILED)

def test_completion_criterion(self) -> None:
def test_should_consider_optimization_complete(self) -> None:
# Tests non-GSS parts of the completion criterion.
gs = self._get_generation_strategy_strategy_for_test(
experiment=self.branin_experiment,
Expand All @@ -2039,7 +2039,7 @@ def test_completion_criterion(self) -> None:
db_settings=self.db_settings_if_always_needed,
)
# With total_trials=None.
should_stop, message = scheduler.completion_criterion()
should_stop, message = scheduler.should_consider_optimization_complete()
self.assertFalse(should_stop)
self.assertEqual(message, "")

Expand All @@ -2049,7 +2049,7 @@ def test_completion_criterion(self) -> None:
**self.scheduler_options_kwargs,
)
# Experiment has fewer trials.
should_stop, message = scheduler.completion_criterion()
should_stop, message = scheduler.should_consider_optimization_complete()
self.assertFalse(should_stop)
self.assertEqual(message, "")
# Experiment has 5 trials.
Expand All @@ -2058,7 +2058,7 @@ def test_completion_criterion(self) -> None:
sobol_run = sobol_generator.gen(n=1)
self.branin_experiment.new_trial(generator_run=sobol_run)
self.assertEqual(len(self.branin_experiment.trials), 5)
should_stop, message = scheduler.completion_criterion()
should_stop, message = scheduler.should_consider_optimization_complete()
self.assertTrue(should_stop)
self.assertEqual(message, "Exceeding the total number of trials.")

Expand Down