diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 4faf63b3f21..1b9066e7456 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -77,8 +77,10 @@ """ FAILURE_EXCEEDED_MSG = ( "Failure rate exceeds the tolerated trial failure rate of {f_rate} (at least " - "{n_failed} out of first {n_ran} trials failed). Checks are triggered both at " - "the end of a optimization and if at least {min_failed} trials have failed." + "{n_failed} out of first {n_ran} trials failed or were abandoned). Checks are " + "triggered both at the end of a optimization and if at least {min_failed} trials " + "have either failed, or have been abandoned, potentially automatically due to " + "issues with the trial." ) @@ -828,16 +830,7 @@ def error_if_failure_rate_exceeded(self, force_check: bool = False) -> None: five failed trials. If True, the check will be performed unless there are 0 failures. """ - bad_idcs = ( - self.experiment.trial_indices_by_status[TrialStatus.FAILED] - | self.experiment.trial_indices_by_status[TrialStatus.ABANDONED] - ) - # We only count failed trials with indices that came after the preexisting - # trials on experiment before scheduler use. - num_bad_in_scheduler = sum( - 1 for f in bad_idcs if f >= self._num_preexisting_trials - ) - + num_bad_in_scheduler = self._num_bad_in_scheduler() # skip check if 0 failures if num_bad_in_scheduler == 0: return @@ -850,10 +843,7 @@ def error_if_failure_rate_exceeded(self, force_check: bool = False) -> None: ): return - num_ran_in_scheduler = ( - len(self.experiment.trials) - self._num_preexisting_trials - ) - + num_ran_in_scheduler = self._num_ran_in_scheduler() failure_rate_exceeded = ( num_bad_in_scheduler / num_ran_in_scheduler ) > self.options.tolerated_trial_failure_rate @@ -872,6 +862,26 @@ def error_if_failure_rate_exceeded(self, force_check: bool = False) -> None: num_ran_in_scheduler=num_ran_in_scheduler, ) + def _num_bad_in_scheduler(self) -> int: + """Returns the number of trials that have failed or been abandoned in the + scheduler. + """ + bad_idcs = ( + self.experiment.trial_indices_by_status[TrialStatus.FAILED] + | self.experiment.trial_indices_by_status[TrialStatus.ABANDONED] + ) + # We only count failed trials with indices that came after the preexisting + # trials on experiment before scheduler use. + return sum(1 for f in bad_idcs if f >= self._num_preexisting_trials) + + def _num_ran_in_scheduler(self) -> int: + """Returns the number of trials that have been run by the scheduler.""" + return sum( + 1 + for idx, t in self.experiment.trials.items() + if idx >= self._num_preexisting_trials and t.status.is_terminal + ) + def run_trials_and_yield_results( self, max_trials: int,