From 4c43b1782d30f97d7570d57f3c602cb227781b93 Mon Sep 17 00:00:00 2001
From: Daniel Cohen
Date: Tue, 24 Sep 2024 08:59:02 -0700
Subject: [PATCH] Predicted Effects Plot (#2777)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/2777

This replicates the functionality of `ax.plot.scatter.plot_fitted`. It will
show predicted effects for all trials with data and for the most recently
created non-abandoned trial (which may also have data). The intent is to use
it for a candidate trial which has just been generated, which will be the case
in the scheduler. But in the event the most recent trial is not a candidate,
it should still work.

To come:
- setting trial index
- mark arms that violate constraints
- sane limit for the number of in-sample arms?

Differential Revision: D62325402
---
 ax/analysis/plotly/predicted_effects.py      | 223 ++++++++++++++
 .../plotly/tests/test_predicted_effects.py   | 287 ++++++++++++++++++
 ax/utils/testing/core_stubs.py               |   6 +-
 3 files changed, 514 insertions(+), 2 deletions(-)
 create mode 100644 ax/analysis/plotly/predicted_effects.py
 create mode 100644 ax/analysis/plotly/tests/test_predicted_effects.py

diff --git a/ax/analysis/plotly/predicted_effects.py b/ax/analysis/plotly/predicted_effects.py
new file mode 100644
index 00000000000..284b9745868
--- /dev/null
+++ b/ax/analysis/plotly/predicted_effects.py
@@ -0,0 +1,223 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from itertools import chain
+from typing import Any, Optional
+
+import pandas as pd
+from ax.analysis.analysis import AnalysisCardLevel
+
+from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
+from ax.core.base_trial import BaseTrial, TrialStatus
+from ax.core.experiment import Experiment
+from ax.core.generation_strategy_interface import GenerationStrategyInterface
+from ax.core.generator_run import GeneratorRun
+from ax.core.observation import ObservationFeatures
+from ax.exceptions.core import UserInputError
+from ax.modelbridge.base import ModelBridge
+from ax.modelbridge.generation_strategy import GenerationStrategy
+from ax.modelbridge.prediction_utils import predict_at_point
+from ax.utils.common.typeutils import checked_cast
+from ax.utils.stats.statstools import relativize
+from plotly import express as px, graph_objects as go, io as pio
+from pyre_extensions import none_throws
+
+
+class PredictedEffectsPlot(PlotlyAnalysis):
+    def __init__(self, metric_name: str) -> None:
+        """
+        Args:
+            metric_name: The name of the metric to plot. If not specified, the
+                objective will be used. Note that the metric cannot be inferred
+                for multi-objective or scalarized-objective experiments.
+        """
+
+        self.metric_name = metric_name
+
+    def compute(
+        self,
+        experiment: Optional[Experiment] = None,
+        generation_strategy: Optional[GenerationStrategyInterface] = None,
+    ) -> PlotlyAnalysisCard:
+        if experiment is None:
+            raise UserInputError("PredictedEffectsPlot requires an Experiment.")
+
+        generation_strategy = checked_cast(
+            GenerationStrategy,
+            generation_strategy,
+            exception=UserInputError(
+                "PredictedEffectsPlot requires a GenerationStrategy."
+            ),
+        )
+        if generation_strategy.model is None:
+            generation_strategy._fit_current_model(data=experiment.lookup_data())
+
+        model = none_throws(generation_strategy.model)
+
+        try:
+            trial_indices = [
+                t.index
+                for t in experiment.trials.values()
+                if t.status != TrialStatus.ABANDONED
+            ]
+            candidate_trial = experiment.trials[max(trial_indices)]
+        except ValueError:
+            raise UserInputError(
+                f"PredictedEffectsPlot cannot be used for {experiment} "
+                "because it has no trials."
+            )
+
+        df = _prepare_data(
+            model=model, metric_name=self.metric_name, candidate_trial=candidate_trial
+        )
+        fig = _prepare_plot(df=df, metric_name=self.metric_name)
+
+        if (
+            experiment.optimization_config is None
+            or self.metric_name not in experiment.optimization_config.metrics
+        ):
+            level = AnalysisCardLevel.LOW
+        elif self.metric_name in experiment.optimization_config.objective.metric_names:
+            level = AnalysisCardLevel.HIGH
+        else:
+            level = AnalysisCardLevel.MID
+
+        return PlotlyAnalysisCard(
+            name="PredictedEffectsPlot",
+            title=f"Predicted Effects for {self.metric_name}",
+            subtitle="View a candidate trial and its arms' predicted metric values",
+            level=level,
+            df=df,
+            blob=pio.to_json(fig),
+        )
+
+
+def _get_predictions(
+    model: ModelBridge,
+    metric_name: str,
+    gr: Optional[GeneratorRun] = None,
+    trial_index: Optional[int] = None,
+) -> list[dict[str, Any]]:
+    if gr is None:
+        observations = model.get_training_data()
+        features = [o.features for o in observations]
+        arm_names = [o.arm_name for o in observations]
+    else:
+        features = [
+            ObservationFeatures(parameters=arm.parameters, trial_index=trial_index)
+            for arm in gr.arms
+        ]
+        arm_names = [a.name for a in gr.arms]
+    try:
+        predictions = [
+            predict_at_point(model=model, obsf=obsf, metric_names={metric_name})
+            for obsf in features
+        ]
+    except NotImplementedError:
+        raise UserInputError(
+            "PredictedEffectsPlot requires a GenerationStrategy which is "
+            "in a state where the current model supports prediction. The current "
+            f"model is {model._model_key} and does not support prediction."
+        )
+    return [
+        {
+            "source": "In-sample" if gr is None else gr._model_key,
+            "arm_name": arm_names[i],
+            "mean": predictions[i][0][metric_name],
+            "error_margin": 1.96 * predictions[i][1][metric_name],
+            **features[i].parameters,
+        }
+        for i in range(len(features))
+    ]
+
+
+def _get_max_observed_trial_index(model: ModelBridge) -> Optional[int]:
+    """Returns the max observed trial index to appease multitask models for prediction
+    by giving fixed features. This is not necessarily accurate and should eventually
+    come from the generation strategy.
+    """
+    observed_trial_indices = [
+        obs.features.trial_index
+        for obs in model.get_training_data()
+        if obs.features.trial_index is not None
+    ]
+    if len(observed_trial_indices) == 0:
+        return None
+    return max(observed_trial_indices)
+
+
+def _prepare_data(
+    model: ModelBridge, metric_name: str, candidate_trial: BaseTrial
+) -> pd.DataFrame:
+    """Prepare data for plotting. Data should include columns for:
+    - source: In-sample or model key that generated the candidate
+    - arm_name: Name of the arm
+    - mean: Predicted metric value
+    - error_margin: 1.96 * predicted sem for plotting 95% CI
+    - **PARAMETER_NAME: The value of each parameter for the arm. Will be used
+        for the tooltip.
+    There will be one row for each arm in the model's training data and one for
+    each arm in the generator runs of the candidate trial.
+    If an arm is in both the training data and the candidate trial, it will
+    only appear once for the candidate trial.
+
+    Args:
+        model: ModelBridge being used for prediction
+        metric_name: Name of metric to plot
+        candidate_trial: Trial to plot candidates for by generator run
+    """
+    trial_index = _get_max_observed_trial_index(model)
+    df = pd.DataFrame.from_records(
+        list(
+            chain(
+                *[
+                    _get_predictions(model, metric_name),
+                    *(
+                        []
+                        if candidate_trial is None
+                        else [
+                            _get_predictions(model, metric_name, gr, trial_index)
+                            for gr in candidate_trial.generator_runs
+                        ]
+                    ),
+                ]
+            )
+        )
+    )
+    df.drop_duplicates(subset="arm_name", keep="last", inplace=True)
+    return df
+
+
+def _get_parameter_columns(df: pd.DataFrame) -> list[str]:
+    """Get the names of the columns that represent parameters in df."""
+    return [
+        col
+        for col in df.columns
+        if col not in ["source", "arm_name", "mean", "error_margin"]
+    ]
+
+
+def _prepare_plot(df: pd.DataFrame, metric_name: str) -> go.Figure:
+    """Prepare a plotly figure for the predicted effects based on the data in df."""
+    fig = px.scatter(
+        df,
+        x="arm_name",
+        y="mean",
+        error_y="error_margin",
+        color="source",
+        hover_data=_get_parameter_columns(df),
+    )
+    if "status_quo" in df["arm_name"].values:
+        fig.add_hline(
+            y=df[df["arm_name"] == "status_quo"]["mean"].iloc[0],
+            line_width=1,
+            line_color="red",
+        )
+    fig.update_layout(
+        xaxis={
+            "tickangle": 45,
+        },
+    )
+    return fig
diff --git a/ax/analysis/plotly/tests/test_predicted_effects.py b/ax/analysis/plotly/tests/test_predicted_effects.py
new file mode 100644
index 00000000000..5d259eee330
--- /dev/null
+++ b/ax/analysis/plotly/tests/test_predicted_effects.py
@@ -0,0 +1,287 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ax.analysis.analysis import AnalysisCardLevel
+from ax.analysis.plotly.predicted_effects import PredictedEffectsPlot
+from ax.core.base_trial import TrialStatus
+from ax.core.observation import ObservationFeatures
+from ax.core.trial import Trial
+from ax.exceptions.core import UserInputError
+from ax.modelbridge.dispatch_utils import choose_generation_strategy
+from ax.modelbridge.generation_node import GenerationNode
+from ax.modelbridge.generation_strategy import GenerationStrategy
+from ax.modelbridge.model_spec import ModelSpec
+from ax.modelbridge.registry import Models
+from ax.modelbridge.transition_criterion import MaxTrials
+from ax.utils.common.testutils import TestCase
+from ax.utils.common.typeutils import checked_cast
+from ax.utils.testing.core_stubs import (
+    get_branin_experiment,
+    get_branin_metric,
+    get_branin_outcome_constraint,
+)
+from ax.utils.testing.mock import fast_botorch_optimize
+from pyre_extensions import none_throws
+
+
+class TestPredictedEffectsPlot(TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.generation_strategy = GenerationStrategy(
+            nodes=[
+                GenerationNode(
+                    node_name="Sobol",
+                    model_specs=[ModelSpec(model_enum=Models.SOBOL)],
+                    transition_criteria=[
+                        MaxTrials(
+                            threshold=1,
+                            transition_to="GPEI",
+                        )
+                    ],
+                ),
+                GenerationNode(
+                    node_name="GPEI",
+                    model_specs=[
+                        ModelSpec(
+                            model_enum=Models.BOTORCH_MODULAR,
+                        ),
+                    ],
+                    transition_criteria=[
+                        MaxTrials(
+                            threshold=1,
+                            transition_to="MTGP",
+                            only_in_statuses=[
+                                TrialStatus.RUNNING,
+                                TrialStatus.COMPLETED,
+                                TrialStatus.EARLY_STOPPED,
+                            ],
+                        )
+                    ],
+                ),
+                GenerationNode(
+                    node_name="MTGP",
+                    model_specs=[
+                        ModelSpec(
+                            model_enum=Models.ST_MTGP,
+                        ),
+                    ],
+                ),
+            ],
+        )
+
+    def test_compute_for_invalid_states(self) -> None:
+        analysis = PredictedEffectsPlot(metric_name="branin")
+        experiment = get_branin_experiment()
+        generation_strategy = choose_generation_strategy(
+            search_space=experiment.search_space,
+            experiment=experiment,
+        )
+
+        with self.assertRaisesRegex(UserInputError, "requires an Experiment"):
+            analysis.compute()
+
+        with self.assertRaisesRegex(UserInputError, "requires a GenerationStrategy"):
+            analysis.compute(experiment=experiment)
+
+        with self.assertRaisesRegex(UserInputError, "it has no trials"):
+            analysis.compute(
+                experiment=experiment, generation_strategy=generation_strategy
+            )
+
+        with self.assertRaisesRegex(
+            UserInputError, "where the current model supports prediction"
+        ):
+            experiment = get_branin_experiment(
+                with_batch=True, with_completed_batch=True
+            )
+            analysis.compute(
+                experiment=experiment, generation_strategy=generation_strategy
+            )
+
+    @fast_botorch_optimize
+    def test_compute(self) -> None:
+        # GIVEN an experiment with metrics and batch trials
+        experiment = get_branin_experiment(with_status_quo=True)
+        none_throws(experiment.optimization_config).outcome_constraints = [
+            get_branin_outcome_constraint(name="constraint_branin")
+        ]
+        experiment.add_tracking_metric(get_branin_metric(name="tracking_branin"))
+        generation_strategy = self.generation_strategy
+        experiment.new_batch_trial(
+            generator_run=generation_strategy.gen(experiment=experiment, n=10)
+        ).set_status_quo_with_weight(
+            status_quo=experiment.status_quo, weight=1.0
+        ).mark_completed(
+            unsafe=True
+        )
+        experiment.fetch_data()
+        experiment.new_batch_trial(
+            generator_run=generation_strategy.gen(experiment=experiment, n=10)
+        ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0)
+        experiment.fetch_data()
+        for metric in experiment.metrics:
+            with self.subTest(metric=metric):
+                # WHEN we compute the analysis for a metric
+                analysis = PredictedEffectsPlot(metric_name=metric)
+                card = analysis.compute(
+                    experiment=experiment, generation_strategy=generation_strategy
+                )
+                # THEN it makes a card with the right name, title, and subtitle
+                self.assertEqual(card.name, "PredictedEffectsPlot")
+                self.assertEqual(card.title, f"Predicted Effects for {metric}")
+                self.assertEqual(
+                    card.subtitle,
+                    "View a candidate trial and its arms' predicted metric values",
+                )
+                # AND THEN it has an appropriate level based on whether we're
+                # optimizing for the metric
+                self.assertEqual(
+                    card.level,
+                    (
+                        AnalysisCardLevel.HIGH
+                        if metric == "branin"
+                        else (
+                            AnalysisCardLevel.MID
+                            if metric == "constraint_branin"
+                            else AnalysisCardLevel.LOW
+                        )
+                    ),
+                )
+                # AND THEN it has the right rows and columns in the dataframe
+                self.assertEqual(
+                    {*card.df.columns},
+                    {"arm_name", "source", "x1", "x2", "mean", "error_margin"},
+                )
+                self.assertIsNotNone(card.blob)
+                self.assertEqual(card.blob_annotation, "plotly")
+                for trial in experiment.trials.values():
+                    for arm in trial.arms:
+                        self.assertIn(arm.name, card.df["arm_name"].unique())
+
+    @fast_botorch_optimize
+    def test_compute_multitask(self) -> None:
+        # GIVEN an experiment with candidates generated with a multitask model
+        experiment = get_branin_experiment()
+        generation_strategy = self.generation_strategy
+        experiment.new_batch_trial(
+            generator_run=generation_strategy.gen(experiment=experiment, n=10)
+        ).mark_completed(unsafe=True)
+        experiment.fetch_data()
+        experiment.new_batch_trial(
+            generator_run=generation_strategy.gen(experiment=experiment, n=10)
+        ).mark_completed(unsafe=True)
+        experiment.fetch_data()
+        # leave as a candidate
+        experiment.new_batch_trial(
+            generator_run=generation_strategy.gen(
+                experiment=experiment,
+                n=10,
+                fixed_features=ObservationFeatures(parameters={}, trial_index=1),
+            )
+        )
+        experiment.new_batch_trial(
+            generator_run=generation_strategy.gen(
+                experiment=experiment,
+                n=10,
+                fixed_features=ObservationFeatures(parameters={}, trial_index=1),
+            )
+        )
+        self.assertEqual(none_throws(generation_strategy.model)._model_key, "ST_MTGP")
+        # WHEN we compute the analysis
+        analysis = PredictedEffectsPlot(metric_name="branin")
+        card = analysis.compute(
+            experiment=experiment, generation_strategy=generation_strategy
+        )
+        # THEN it has the right rows for arms with data, as well as the latest trial
+        arms_with_data = set(experiment.lookup_data().df["arm_name"].unique())
+        max_trial_index = max(experiment.trials.keys())
+        for trial in experiment.trials.values():
+            if trial.status.expecting_data or trial.index == max_trial_index:
+                for arm in trial.arms:
+                    self.assertIn(arm.name, card.df["arm_name"].unique())
+            else:
+                # arms from other candidate trials are only in the df if they
+                # are repeated in the target trial
+                for arm in trial.arms:
+                    self.assertTrue(
+                        arm.name not in card.df["arm_name"].unique()
+                        # it's repeated in another trial
+                        or arm.name in arms_with_data
+                        or arm.name in experiment.trials[max_trial_index].arms_by_name,
+                        arm.name,
+                    )
+
+    @fast_botorch_optimize
+    def test_it_does_not_plot_abandoned_trials(self) -> None:
+        # GIVEN an experiment with candidate and abandoned trials
+        experiment = get_branin_experiment()
+        generation_strategy = self.generation_strategy
+        experiment.new_batch_trial(
+            generator_run=generation_strategy.gen(experiment=experiment, n=10)
+        ).mark_completed(unsafe=True)
+        experiment.fetch_data()
+        # candidate trial
+        experiment.new_batch_trial(
+            generator_run=generation_strategy.gen(experiment=experiment, n=10)
+        )
+        experiment.new_batch_trial(
+            generator_run=generation_strategy.gen(experiment=experiment, n=10)
+        ).mark_abandoned()
+        arms_with_data = set(experiment.lookup_data().df["arm_name"].unique())
+        # WHEN we compute the analysis
+        analysis = PredictedEffectsPlot(metric_name="branin")
+        card = analysis.compute(
+            experiment=experiment, generation_strategy=generation_strategy
+        )
+        # THEN it has the right rows for arms with data, as well as the latest
+        # non-abandoned trial (with index 1)
+        for arm in experiment.trials[0].arms + experiment.trials[1].arms:
+            self.assertIn(arm.name, card.df["arm_name"].unique())
+
+        # AND THEN it does not have the arms from the abandoned trial (index 2)
+        for arm in experiment.trials[2].arms:
+            self.assertTrue(
+                arm.name not in card.df["arm_name"].unique()
+                # it's repeated in another trial
+                or arm.name in arms_with_data
+                or arm.name in experiment.trials[1].arms_by_name,
+                arm.name,
+            )
+
+    @fast_botorch_optimize
+    def test_it_works_for_non_batch_experiments(self) -> None:
+        # GIVEN an experiment with the default generation strategy
+        experiment = get_branin_experiment(with_batch=False)
+        generation_strategy = choose_generation_strategy(
+            search_space=experiment.search_space,
+            experiment=experiment,
+        )
+        # AND GIVEN we generate all Sobol trials and one GPEI trial
+        sobol_key = Models.SOBOL.value
+        last_model_key = sobol_key
+        while last_model_key == sobol_key:
+            trial = experiment.new_trial(
+                generator_run=generation_strategy.gen(
+                    experiment=experiment, n=1, pending_observation=True
+                )
+            )
+            last_model_key = none_throws(trial.generator_run)._model_key
+            if last_model_key == sobol_key:
+                trial.mark_running(no_runner_required=True)
+                trial.mark_completed()
+                trial.fetch_data()
+
+        # WHEN we compute the analysis
+        analysis = PredictedEffectsPlot(metric_name="branin")
+        card = analysis.compute(
+            experiment=experiment,
+            generation_strategy=generation_strategy,
+        )
+        # THEN it has all arms represented in the dataframe
+        for trial in experiment.trials.values():
+            self.assertIn(
+                none_throws(checked_cast(Trial, trial).arm).name,
+                card.df["arm_name"].unique(),
+            )
diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py
index 2823ff7a6e6..843746f27ee 100644
--- a/ax/utils/testing/core_stubs.py
+++ b/ax/utils/testing/core_stubs.py
@@ -1542,8 +1542,10 @@ def get_scalarized_outcome_constraint() -> ScalarizedOutcomeConstraint:
     )
 
 
-def get_branin_outcome_constraint() -> OutcomeConstraint:
-    return OutcomeConstraint(metric=get_branin_metric(), op=ComparisonOp.LEQ, bound=0.0)
+def get_branin_outcome_constraint(name: str = "branin") -> OutcomeConstraint:
+    return OutcomeConstraint(
+        metric=get_branin_metric(name=name), op=ComparisonOp.LEQ, bound=0.0
+    )
 
 
 ##############################
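
Reviewer note (not part of the patch): a minimal usage sketch of the new analysis, mirroring test_it_works_for_non_batch_experiments above. It only uses helpers already exercised in this diff (get_branin_experiment, choose_generation_strategy) and is an illustration under those test-stub assumptions, not the canonical way to invoke the plot.

    from ax.analysis.plotly.predicted_effects import PredictedEffectsPlot
    from ax.modelbridge.dispatch_utils import choose_generation_strategy
    from ax.modelbridge.registry import Models
    from ax.utils.testing.core_stubs import get_branin_experiment

    experiment = get_branin_experiment()
    generation_strategy = choose_generation_strategy(
        search_space=experiment.search_space, experiment=experiment
    )

    # Complete Sobol trials until the strategy switches to a model that supports
    # prediction, then leave the last generated trial as an open candidate.
    last_model_key = Models.SOBOL.value
    while last_model_key == Models.SOBOL.value:
        trial = experiment.new_trial(
            generator_run=generation_strategy.gen(experiment=experiment, n=1)
        )
        last_model_key = trial.generator_run._model_key
        if last_model_key == Models.SOBOL.value:
            trial.mark_running(no_runner_required=True)
            trial.mark_completed()
            trial.fetch_data()

    # The candidate trial's arms are plotted alongside all in-sample arms.
    card = PredictedEffectsPlot(metric_name="branin").compute(
        experiment=experiment, generation_strategy=generation_strategy
    )
    print(card.df[["source", "arm_name", "mean", "error_margin"]])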