diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 1128f7d1f2a..55388093674 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -1489,85 +1489,92 @@ def _apply_new_trial_statuses( trial.mark_as(status=status, unsafe=True) return updated_trial_indices - def _get_trial_indices_to_fetch( - self, new_status_to_trial_idcs: Mapping[TrialStatus, set[int]] + def _identify_trial_indices_to_fetch( + self, + old_status_to_trial_indices: Mapping[TrialStatus, set[int]], + new_status_to_trial_indices: Mapping[TrialStatus, set[int]], ) -> set[int]: - """Get trial indices to fetch data for the experiment given - `new_status_to_trial_idcs` and metric properties. This should include: - - newly completed trials (about to be completed) - - running trials if the experiment has metrics available while running - - previously completed (or early stopped) trials if the experiment - has metrics with new data after completion which finished recently - + """ + Identify trial indices to fetch data for based on changes in trial statuses. Args: - new_status_to_trial_idcs: Changes about to be applied to trial statuses. - + old_status_to_trial_indices: Mapping of old trial statuses + to their corresponding trial indices. + new_status_to_trial_indices: Mapping of new trial statuses + to their corresponding trial indices. Returns: Set of trial indices to fetch data for. """ - terminated_trial_idcs = { - index - for status, indices in new_status_to_trial_idcs.items() - if status.is_terminal - for index in indices - } - running_trial_indices = { - trial.index - for trial in self.running_trials - if trial.index not in terminated_trial_idcs - } - # add in any trials that will be marked running - running_trial_indices.update( - new_status_to_trial_idcs.get(TrialStatus.RUNNING, set()) - ) - - # includes completed and early stopped trials - prev_completed_trial_idcs = { - t.index for t in self.trials_expecting_data - } - self.running_trial_indices - trial_indices_to_fetch = set() - - # Fetch data for newly completed trials - newly_completed = ( - new_status_to_trial_idcs.get(TrialStatus.COMPLETED, set()) - - prev_completed_trial_idcs - ) + # Get newly completed trials + newly_completed = new_status_to_trial_indices.get(TrialStatus.COMPLETED, set()) idcs = make_indices_str(indices=newly_completed) if newly_completed: self.logger.info(f"Fetching data for newly completed trials: {idcs}.") - trial_indices_to_fetch.update(newly_completed) else: self.logger.info("No newly completed trials; not fetching data for any.") - # Fetch data for running trials that have metrics available while running - if ( - any( - m.is_available_while_running() for m in self.experiment.metrics.values() - ) - and len(running_trial_indices) > 0 + # Get running trials with metrics available while running + running_trial_indices_with_metrics = set() + if any( + m.is_available_while_running() for m in self.experiment.metrics.values() ): - # NOTE: Metrics that are *not* available_while_running will be skipped - # in fetch_trials_data - idcs = make_indices_str(indices=running_trial_indices) - self.logger.info( - f"Fetching data for trials: {idcs} because some metrics " - "on experiment are available while trials are running." - ) - trial_indices_to_fetch.update(running_trial_indices) + running_trial_indices_with_metrics = new_status_to_trial_indices.get( + TrialStatus.RUNNING, set() + ) | old_status_to_trial_indices.get(TrialStatus.RUNNING, set()) + + for status, indices in new_status_to_trial_indices.items(): + if status.is_terminal and indices: + running_trial_indices_with_metrics -= indices - # Fetch data for previously completed trials that have metrics available - # after trial completion that were completed within the max of the period - # specified by metrics + if running_trial_indices_with_metrics: + idcs = make_indices_str(indices=running_trial_indices_with_metrics) + self.logger.info( + f"Fetching data for trials: {idcs} because some metrics " + "on experiment are available while trials are running." + ) + + # Get previously completed trials with new data after completion recently_completed_trial_indices = self._get_recently_completed_trial_indices() - if len(recently_completed_trial_indices) > 0: + if recently_completed_trial_indices: idcs = make_indices_str(indices=recently_completed_trial_indices) self.logger.info( f"Fetching data for trials: {idcs} because some metrics " "on experiment have new data after completion." ) - trial_indices_to_fetch.update(recently_completed_trial_indices) + + # Combine all trial indices to fetch data for + trial_indices_to_fetch = ( + newly_completed + | running_trial_indices_with_metrics + | recently_completed_trial_indices + ) + return trial_indices_to_fetch + def _get_trial_indices_to_fetch( + self, new_status_to_trial_idcs: Mapping[TrialStatus, set[int]] + ) -> set[int]: + """Get trial indices to fetch data for the experiment given + `new_status_to_trial_idcs` and metric properties. This should include: + - newly completed trials (about to be completed) + - running trials if the experiment has metrics available while running + - previously completed (or early stopped) trials if the experiment + has metrics with new data after completion which finished recently + + Args: + new_status_to_trial_idcs: Changes about to be applied to trial statuses. + + Returns: + Set of trial indices to fetch data for. + """ + old_status_to_trial_idcs = {status: set() for status in TrialStatus} + + for trial in self.trials: + old_status_to_trial_idcs[trial.status].add(trial.index) + + return self._identify_trial_indices_to_fetch( + old_status_to_trial_idcs, new_status_to_trial_idcs + ) + def _get_recently_completed_trial_indices(self) -> set[int]: """Get trials that have completed within the max period specified by metrics.""" if len(self.experiment.metrics) == 0: