From a4a97d69e3e113d568fe89df92ae9eeda7eeb1f4 Mon Sep 17 00:00:00 2001 From: Lena Kashtelyan Date: Wed, 11 Sep 2024 12:00:18 -0700 Subject: [PATCH] Remove unnecessary wrapper-helper around early-stopping call (#2666) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2666 As titled Reviewed By: saitcakmak, Balandat Differential Revision: D55710758 fbshipit-source-id: afb1ca0d971bf79f470d58e0e3b4584588cbd962 --- ax/service/scheduler.py | 22 +----- ax/utils/testing/backend_scheduler.py | 109 -------------------------- 2 files changed, 3 insertions(+), 128 deletions(-) delete mode 100644 ax/utils/testing/backend_scheduler.py diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 32beeabab17..188d5d2d79b 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -1233,8 +1233,10 @@ def poll_and_process_results(self, poll_all_trial_statuses: bool = False) -> boo trial_indices_with_updated_data_or_status.update(trial_indices_with_new_data) # EARLY STOP TRIALS - stop_trial_info = self.should_stop_trials_early( + stop_trial_info = early_stopping_utils.should_stop_trials_early( + early_stopping_strategy=self.options.early_stopping_strategy, trial_indices=self.experiment.running_trial_indices, + experiment=self.experiment, ) self.stop_trial_runs( trials=[self.experiment.trials[trial_idx] for trial_idx in stop_trial_info], @@ -1428,24 +1430,6 @@ def _process_completed_trials(self, newly_completed: set[int]) -> None: trial_indices=newly_completed, ) - def should_stop_trials_early( - self, trial_indices: set[int] - ) -> dict[int, Optional[str]]: - """Evaluate whether to early-stop running trials. - - Args: - trial_indices: Indices of trials to consider for early stopping. - - Returns: - A dictionary mapping trial indices that should be early stopped to - (optional) messages with the associated reason. - """ - return early_stopping_utils.should_stop_trials_early( - early_stopping_strategy=self.options.early_stopping_strategy, - trial_indices=trial_indices, - experiment=self.experiment, - ) - def estimate_early_stopping_savings(self, map_key: Optional[str] = None) -> float: """Estimate early stopping savings using progressions of the MapMetric present on the EarlyStoppingConfig as a proxy for resource usage. diff --git a/ax/utils/testing/backend_scheduler.py b/ax/utils/testing/backend_scheduler.py deleted file mode 100644 index c4b87afbd78..00000000000 --- a/ax/utils/testing/backend_scheduler.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -from __future__ import annotations - -from dataclasses import replace as dataclass_replace - -from logging import Logger -from typing import Optional - -from ax.core.experiment import Experiment -from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.runners.simulated_backend import SimulatedBackendRunner -from ax.service.scheduler import Scheduler, SchedulerOptions -from ax.utils.common.logger import get_logger -from ax.utils.testing.backend_simulator import BackendSimulator - -logger: Logger = get_logger(__name__) - - -class AsyncSimulatedBackendScheduler(Scheduler): - """A Scheduler that uses a simulated backend for Ax asynchronous benchmarks.""" - - def __init__( - self, - experiment: Experiment, - generation_strategy: GenerationStrategy, - max_pending_trials: int, - options: SchedulerOptions, - ) -> None: - """A Scheduler for Ax asynchronous benchmarks. - - Args: - experiment: Experiment, in which results of the optimization - will be recorded. - generation_strategy: Generation strategy for the optimization, - describes models that will be used in optimization. - max_pending_trials: The maximum number of pending trials allowed. - options: `SchedulerOptions` for this Scheduler instance. - """ - if not isinstance(experiment.runner, SimulatedBackendRunner): - raise ValueError( - "experiment must have runner of type SimulatedBackendRunner attached" - ) - if ( - options.max_pending_trials is not None - and options.max_pending_trials != max_pending_trials - ): - raise ValueError( - f"`SchedulerOptions.max_pending_trials`: {options.max_pending_trials} " - f"does not match argument to `Scheduler`: {max_pending_trials}." - ) - if options.max_pending_trials is None: - options = dataclass_replace(options, max_pending_trials=max_pending_trials) - - super().__init__( - experiment=experiment, - generation_strategy=generation_strategy, - options=options, - _skip_experiment_save=True, - ) - - @property - def backend_simulator(self) -> BackendSimulator: - """Get the ``BackendSimulator`` stored on the runner of the experiment. - - Returns: - The backend simulator. - """ - return self.runner.simulator # pyre-ignore[16] - - def should_stop_trials_early( - self, trial_indices: set[int] - ) -> dict[int, Optional[str]]: - """Given a set of trial indices, decide whether or not to early-stop - running trials using the ``early_stopping_strategy``. - - Args: - trial_indices: Indices of trials to consider for early stopping. - - Returns: - Dict with new suggested ``TrialStatus`` as keys and a set of - indices of trials to update (subset of initially-passed trials) as values. - """ - # TODO: The status on the experiment does not distinguish between - # running and queued trials, so here we check status on the - # ``backend_simulator`` directly to make sure it is running. - running_trials = set() - skipped_trials = set() - for trial_index in trial_indices: - sim_trial = self.backend_simulator.get_sim_trial_by_index(trial_index) - if sim_trial.sim_start_time is not None and ( # pyre-ignore[16] - self.backend_simulator.time - sim_trial.sim_start_time > 0 - ): - running_trials.add(trial_index) - else: - skipped_trials.add(trial_index) - if len(skipped_trials) > 0: - logger.info( - f"Not sending {skipped_trials} to base `should_stop_trials_early` " - "because they have not been running for a positive amount of time " - "on the backend simulator." - ) - return super().should_stop_trials_early(trial_indices=running_trials)