From bd4fe421aa1d8be2b8b518eeb3204c138f6dd059 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Wed, 6 Dec 2023 08:31:56 -0800 Subject: [PATCH] add improvement_over_baseline for SOO cases (#2046) Summary: Adding functionality for computing "improvement_over_baseline" from a scheduler object. Reviewed By: mpolson64 Differential Revision: D51726168 --- ax/service/scheduler.py | 65 +++++++++++++++++++++ ax/service/tests/test_scheduler.py | 93 ++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+) diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index fde94b24cde..07bf322544e 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -543,6 +543,71 @@ def summarize_final_result(self) -> OptimizationResult: """ return OptimizationResult() + def get_improvement_over_baseline( + self, + baseline_arm_name: Optional[str] = None, + ) -> float: + """Returns the scalarized improvement over baseline, if applicable. + + Returns: + For Single Objective cases, returns % improvement of objective. + Positive indicates improvement over baseline. Negative indicates regression. + For Multi Objective cases, raises NotImplementedError. + """ + if self.experiment.is_moo_problem: + raise NotImplementedError( + "`get_improvement_over_baseline` not yet implemented" + + " for multi-objective problems." + ) + if not baseline_arm_name: + raise UserInputError( + "`get_improvement_over_baseline` missing required parameter: " + + f"{baseline_arm_name=}." + ) + + optimization_config = self.experiment.optimization_config + if not optimization_config: + raise ValueError("No optimization config found.") + + objective_metric_name = optimization_config.objective.metric.name + + # get the baseline trial + data = self.experiment.lookup_data().df + data = data[data["arm_name"] == baseline_arm_name] + if len(data) == 0: + raise UserInputError( + "`get_improvement_over_baseline`" + " could not find baseline arm" + f" `{baseline_arm_name}` in the experiment data." 
+ ) + data = data[data["metric_name"] == objective_metric_name] + baseline_value = data.iloc[0]["mean"] + + # Find objective value of the best trial + idx, param, best_arm = not_none( + self.get_best_trial( + optimization_config=optimization_config, use_model_predictions=False + ) + ) + best_arm = not_none(best_arm) + best_obj_value = best_arm[0][objective_metric_name] + + def percent_change(x: float, y: float, minimize: bool) -> float: + if x == 0: + raise ZeroDivisionError( + "Cannot compute percent improvement when denom is zero" + ) + percent_change = (y - x) / abs(x) * 100 + if minimize: + percent_change = -percent_change + return percent_change + + return percent_change( + x=baseline_value, + y=best_obj_value, + minimize=optimization_config.objective.minimize, + ) + # ---------- Methods below should generally not be modified in subclasses. --------- @retry_on_exception(retries=3, no_retry_on_exception_types=NO_RETRY_EXCEPTIONS) diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py index 9626498caa2..edf3090d01c 100644 --- a/ax/service/tests/test_scheduler.py +++ b/ax/service/tests/test_scheduler.py @@ -1384,3 +1384,96 @@ def test_standard_generation_strategy(self) -> None: "only supported with instances of `GenerationStrategy`", ): scheduler.standard_generation_strategy + + def test_get_improvement_over_baseline(self) -> None: + n_total_trials = 8 + + scheduler = Scheduler( + experiment=self.branin_experiment, # Has runner and metrics. + generation_strategy=self.two_sobol_steps_GS, + options=SchedulerOptions( + total_trials=n_total_trials, + # pyre-fixme[6]: For 2nd param expected `Optional[int]` but got `float`. + init_seconds_between_polls=0.1, # Short between polls so test is fast. 
+ ), + ) + + scheduler.run_all_trials() + + first_trial_name = ( + scheduler.experiment.trials[0].lookup_data().df["arm_name"].iloc[0] + ) + percent_improvement = scheduler.get_improvement_over_baseline( + baseline_arm_name=first_trial_name, + ) + + # Assert that the best trial improves, or + # at least doesn't regress, over the first trial. + self.assertGreaterEqual(percent_improvement, 0.0) + + def test_get_improvement_over_baseline_robustness(self) -> None: + """Test edge cases for get_improvement_over_baseline""" + experiment = get_branin_experiment_with_multi_objective() + experiment.runner = self.runner + + scheduler = Scheduler( + experiment=experiment, + generation_strategy=self.sobol_GPEI_GS, + # pyre-fixme[6]: For 1st param expected `Optional[int]` but got `float`. + options=SchedulerOptions(init_seconds_between_polls=0.1), + ) + + with self.assertRaises(NotImplementedError): + scheduler.get_improvement_over_baseline( + baseline_arm_name=None, + ) + + scheduler = Scheduler( + experiment=self.branin_experiment, # Has runner and metrics. + generation_strategy=self.two_sobol_steps_GS, + options=SchedulerOptions( + total_trials=2, + # pyre-fixme[6]: For 2nd param expected `Optional[int]` but got `float`. + init_seconds_between_polls=0.1, # Short between polls so test is fast. 
+ ), + ) + + with self.assertRaises(UserInputError): + scheduler.get_improvement_over_baseline( + baseline_arm_name=None, + ) + + exp = scheduler.experiment + exp_copy = Experiment( + search_space=exp.search_space, + name=exp.name, + optimization_config=None, + tracking_metrics=exp.tracking_metrics, + runner=exp.runner, + ) + scheduler.experiment = exp_copy + + with self.assertRaises(ValueError): + scheduler.get_improvement_over_baseline(baseline_arm_name="baseline") + + def test_get_improvement_over_baseline_no_baseline(self) -> None: + """Test that get_improvement_over_baseline returns UserInputError when + baseline is not found in data.""" + n_total_trials = 8 + + scheduler = Scheduler( + experiment=self.branin_experiment, # Has runner and metrics. + generation_strategy=self.two_sobol_steps_GS, + options=SchedulerOptions( + total_trials=n_total_trials, + # pyre-fixme[6]: For 2nd param expected `Optional[int]` but got `float`. + init_seconds_between_polls=0.1, # Short between polls so test is fast. + ), + ) + + scheduler.run_all_trials() + + with self.assertRaises(UserInputError): + scheduler.get_improvement_over_baseline( + baseline_arm_name="baseline_arm_not_in_data", + )