Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor _get_trial_indices_to_fetch #3086

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 65 additions & 58 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,85 +1489,92 @@ def _apply_new_trial_statuses(
trial.mark_as(status=status, unsafe=True)
return updated_trial_indices

def _get_trial_indices_to_fetch(
self, new_status_to_trial_idcs: Mapping[TrialStatus, set[int]]
def _identify_trial_indices_to_fetch(
self,
old_status_to_trial_indices: Mapping[TrialStatus, set[int]],
new_status_to_trial_indices: Mapping[TrialStatus, set[int]],
) -> set[int]:
"""Get trial indices to fetch data for the experiment given
`new_status_to_trial_idcs` and metric properties. This should include:
- newly completed trials (about to be completed)
- running trials if the experiment has metrics available while running
- previously completed (or early stopped) trials if the experiment
has metrics with new data after completion which finished recently

"""
Identify trial indices to fetch data for based on changes in trial statuses.
Args:
new_status_to_trial_idcs: Changes about to be applied to trial statuses.

old_status_to_trial_indices: Mapping of old trial statuses
to their corresponding trial indices.
new_status_to_trial_indices: Mapping of new trial statuses
to their corresponding trial indices.
Returns:
Set of trial indices to fetch data for.
"""
terminated_trial_idcs = {
index
for status, indices in new_status_to_trial_idcs.items()
if status.is_terminal
for index in indices
}
running_trial_indices = {
trial.index
for trial in self.running_trials
if trial.index not in terminated_trial_idcs
}
# add in any trials that will be marked running
running_trial_indices.update(
new_status_to_trial_idcs.get(TrialStatus.RUNNING, set())
)

# includes completed and early stopped trials
prev_completed_trial_idcs = {
t.index for t in self.trials_expecting_data
} - self.running_trial_indices
trial_indices_to_fetch = set()

# Fetch data for newly completed trials
newly_completed = (
new_status_to_trial_idcs.get(TrialStatus.COMPLETED, set())
- prev_completed_trial_idcs
)
# Get newly completed trials
newly_completed = new_status_to_trial_indices.get(TrialStatus.COMPLETED, set())
idcs = make_indices_str(indices=newly_completed)
if newly_completed:
self.logger.info(f"Fetching data for newly completed trials: {idcs}.")
trial_indices_to_fetch.update(newly_completed)
else:
self.logger.info("No newly completed trials; not fetching data for any.")

# Fetch data for running trials that have metrics available while running
if (
any(
m.is_available_while_running() for m in self.experiment.metrics.values()
)
and len(running_trial_indices) > 0
# Get running trials with metrics available while running
running_trial_indices_with_metrics = set()
if any(
m.is_available_while_running() for m in self.experiment.metrics.values()
):
# NOTE: Metrics that are *not* available_while_running will be skipped
# in fetch_trials_data
idcs = make_indices_str(indices=running_trial_indices)
self.logger.info(
f"Fetching data for trials: {idcs} because some metrics "
"on experiment are available while trials are running."
)
trial_indices_to_fetch.update(running_trial_indices)
running_trial_indices_with_metrics = new_status_to_trial_indices.get(
TrialStatus.RUNNING, set()
) | old_status_to_trial_indices.get(TrialStatus.RUNNING, set())

for status, indices in new_status_to_trial_indices.items():
if status.is_terminal and indices:
running_trial_indices_with_metrics -= indices

# Fetch data for previously completed trials that have metrics available
# after trial completion that were completed within the max of the period
# specified by metrics
if running_trial_indices_with_metrics:
idcs = make_indices_str(indices=running_trial_indices_with_metrics)
self.logger.info(
f"Fetching data for trials: {idcs} because some metrics "
"on experiment are available while trials are running."
)

# Get previously completed trials with new data after completion
recently_completed_trial_indices = self._get_recently_completed_trial_indices()
if len(recently_completed_trial_indices) > 0:
if recently_completed_trial_indices:
idcs = make_indices_str(indices=recently_completed_trial_indices)
self.logger.info(
f"Fetching data for trials: {idcs} because some metrics "
"on experiment have new data after completion."
)
trial_indices_to_fetch.update(recently_completed_trial_indices)

# Combine all trial indices to fetch data for
trial_indices_to_fetch = (
newly_completed
| running_trial_indices_with_metrics
| recently_completed_trial_indices
)

return trial_indices_to_fetch

def _get_trial_indices_to_fetch(
    self, new_status_to_trial_idcs: Mapping[TrialStatus, set[int]]
) -> set[int]:
    """Determine which trial indices the experiment should fetch data for.

    Builds a snapshot of the statuses currently recorded on ``self.trials``
    and hands it, together with the pending status changes, to
    ``_identify_trial_indices_to_fetch``. Per the contract of this method,
    the result should include:
    - newly completed trials (about to be completed)
    - running trials if the experiment has metrics available while running
    - previously completed (or early stopped) trials if the experiment
    has metrics with new data after completion which finished recently

    Args:
        new_status_to_trial_idcs: Changes about to be applied to trial statuses.

    Returns:
        Set of trial indices to fetch data for.
    """
    # Seed an entry for every status so the snapshot always carries the full
    # set of keys, even for statuses no trial currently has.
    current_status_to_indices: dict[TrialStatus, set[int]] = {}
    for status in TrialStatus:
        current_status_to_indices[status] = set()

    # Group each trial's index under the status it holds right now, i.e.
    # before the new statuses are applied.
    for trial in self.trials:
        current_status_to_indices[trial.status].add(trial.index)

    return self._identify_trial_indices_to_fetch(
        current_status_to_indices, new_status_to_trial_idcs
    )

def _get_recently_completed_trial_indices(self) -> set[int]:
"""Get trials that have completed within the max period specified by metrics."""
if len(self.experiment.metrics) == 0:
Expand Down