From 58008063a401567cfd5ddff37579d1cfbaee8420 Mon Sep 17 00:00:00 2001
From: Elizabeth Santorella
Date: Tue, 20 Aug 2024 08:45:21 -0700
Subject: [PATCH] Migrate Jenatton to use BenchmarkRunner and BenchmarkMetric
 (#2676)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/2676

This PR:
- Has Jenatton use `ParamBasedTestProblem` so that it can use
  `ParamBasedTestProblemRunner` and `BenchmarkMetric`, getting rid of the
  specialized Jenatton runners and metrics. This enables Jenatton to handle
  noisy problems, whether or not noise levels are observed, like other
  benchmark problems, and will make it easy to add constraints or benefit
  from other new functionality.
- Does *not* clean up the now-unnecessary Jenatton metric file; that happens
  in the next diff.

Differential Revision: D61502458
---
 ax/benchmark/metrics/jenatton.py              |  69 +----
 .../problems/synthetic/hss/jenatton.py        |  56 ++++-
 ax/benchmark/tests/metrics/test_jennaton.py   | 236 +++++++++++++-----
 ax/benchmark/tests/test_benchmark.py          |  22 +-
 ax/storage/json_store/registry.py             |   3 -
 .../json_store/tests/test_json_store.py       |   4 +-
 6 files changed, 237 insertions(+), 153 deletions(-)

diff --git a/ax/benchmark/metrics/jenatton.py b/ax/benchmark/metrics/jenatton.py
index dd1da3205fd..4be7f5eae1b 100644
--- a/ax/benchmark/metrics/jenatton.py
+++ b/ax/benchmark/metrics/jenatton.py
@@ -5,78 +5,11 @@
 
 # pyre-strict
 
-from __future__ import annotations
+from typing import Optional
 
-from typing import Any, Optional
-
-import numpy as np
-import pandas as pd
-from ax.benchmark.metrics.base import BenchmarkMetricBase, GroundTruthMetricMixin
-from ax.core.base_trial import BaseTrial
-from ax.core.data import Data
-from ax.core.metric import MetricFetchE, MetricFetchResult
-from ax.utils.common.result import Err, Ok
 from ax.utils.common.typeutils import not_none
 
 
-class JenattonMetric(BenchmarkMetricBase):
-    """Jenatton metric for hierarchical search spaces."""
-
-    has_ground_truth: bool = True
-
-    def __init__(
-        self,
-        name: str = "jenatton",
-        noise_std: float = 0.0,
-        observe_noise_sd: bool = False,
-    ) -> None:
-        super().__init__(name=name)
-        self.noise_std = noise_std
-        self.observe_noise_sd = observe_noise_sd
-        self.lower_is_better = True
-
-    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
-        try:
-            mean = [
-                jenatton_test_function(**arm.parameters)  # pyre-ignore [6]
-                for _, arm in trial.arms_by_name.items()
-            ]
-            if self.noise_std != 0:
-                mean = [m + self.noise_std * np.random.randn() for m in mean]
-            df = pd.DataFrame(
-                {
-                    "arm_name": [name for name, _ in trial.arms_by_name.items()],
-                    "metric_name": self.name,
-                    "mean": mean,
-                    "sem": self.noise_std if self.observe_noise_sd else None,
-                    "trial_index": trial.index,
-                }
-            )
-            return Ok(value=Data(df=df))
-
-        except Exception as e:
-            return Err(
-                MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
-            )
-
-    def make_ground_truth_metric(self) -> GroundTruthJenattonMetric:
-        return GroundTruthJenattonMetric(original_metric=self)
-
-
-class GroundTruthJenattonMetric(JenattonMetric, GroundTruthMetricMixin):
-    def __init__(self, original_metric: JenattonMetric) -> None:
-        """
-        Args:
-            original_metric: The original JenattonMetric to which this metric
-                corresponds.
-        """
-        super().__init__(
-            name=self.get_ground_truth_name(original_metric),
-            noise_std=0.0,
-            observe_noise_sd=False,
-        )
-
-
 def jenatton_test_function(
     x1: Optional[int] = None,
     x2: Optional[int] = None,
diff --git a/ax/benchmark/problems/synthetic/hss/jenatton.py b/ax/benchmark/problems/synthetic/hss/jenatton.py
index f545ac39400..67fa1755372 100644
--- a/ax/benchmark/problems/synthetic/hss/jenatton.py
+++ b/ax/benchmark/problems/synthetic/hss/jenatton.py
@@ -5,18 +5,52 @@
 
 # pyre-strict
 
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
 from ax.benchmark.benchmark_problem import BenchmarkProblem
-from ax.benchmark.metrics.jenatton import JenattonMetric
+from ax.benchmark.metrics.benchmark import BenchmarkMetric
+from ax.benchmark.metrics.jenatton import jenatton_test_function
+from ax.benchmark.runners.botorch_test import (
+    ParamBasedTestProblem,
+    ParamBasedTestProblemRunner,
+)
 from ax.core.objective import Objective
 from ax.core.optimization_config import OptimizationConfig
 from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
 from ax.core.search_space import HierarchicalSearchSpace
-from ax.runners.synthetic import SyntheticRunner
+from ax.core.types import TParameterization
+
+
+@dataclass(kw_only=True)
+class Jenatton(ParamBasedTestProblem):
+    r"""Jenatton test function for hierarchical search spaces.
+
+    This function is taken from:
+
+    R. Jenatton, C. Archambeau, J. González, and M. Seeger. Bayesian
+    optimization with tree-structured dependencies. ICML 2017.
+    """
+
+    noise_std: Optional[float] = None
+    negate: bool = False
+    num_objectives: int = 1
+    optimal_value: float = 0.1
+    _is_constrained: bool = False
+
+    def evaluate_true(self, params: TParameterization) -> torch.Tensor:
+        # pyre-fixme: Incompatible parameter type [6]: In call
+        # `jenatton_test_function`, for 1st positional argument, expected
+        # `Optional[float]` but got `Union[None, bool, float, int, str]`.
+        value = jenatton_test_function(**params)
+        return torch.tensor(value)
 
 
 def get_jenatton_benchmark_problem(
     num_trials: int = 50,
     observe_noise_sd: bool = False,
+    noise_std: float = 0.0,
 ) -> BenchmarkProblem:
     search_space = HierarchicalSearchSpace(
         parameters=[
@@ -55,24 +89,28 @@ def get_jenatton_benchmark_problem(
             ),
         ]
     )
+    name = "Jenatton" + ("_observed_noise" if observe_noise_sd else "")
     optimization_config = OptimizationConfig(
         objective=Objective(
-            metric=JenattonMetric(observe_noise_sd=observe_noise_sd),
+            metric=BenchmarkMetric(
+                name=name, observe_noise_sd=observe_noise_sd, lower_is_better=True
+            ),
             minimize=True,
         )
    )
-
-    name = "Jenatton" + ("_observed_noise" if observe_noise_sd else "")
-
     return BenchmarkProblem(
         name=name,
         search_space=search_space,
         optimization_config=optimization_config,
-        runner=SyntheticRunner(),
+        runner=ParamBasedTestProblemRunner(
+            test_problem_class=Jenatton,
+            test_problem_kwargs={"noise_std": noise_std},
+            outcome_names=[name],
+        ),
         num_trials=num_trials,
-        is_noiseless=True,
+        is_noiseless=noise_std == 0.0,
         observe_noise_stds=observe_noise_sd,
         has_ground_truth=True,
-        optimal_value=0.1,
+        optimal_value=Jenatton.optimal_value,
     )
diff --git a/ax/benchmark/tests/metrics/test_jennaton.py b/ax/benchmark/tests/metrics/test_jennaton.py
index f7cb2474a13..ab4933e3c0e 100644
--- a/ax/benchmark/tests/metrics/test_jennaton.py
+++ b/ax/benchmark/tests/metrics/test_jennaton.py
@@ -7,107 +7,217 @@
 
 import math
 from random import random
-from unittest import mock
 
-from ax.benchmark.metrics.jenatton import jenatton_test_function, JenattonMetric
+from ax.benchmark.metrics.benchmark import BenchmarkMetric, GroundTruthBenchmarkMetric
+
+from ax.benchmark.metrics.jenatton import jenatton_test_function
+from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_benchmark_problem
+from ax.benchmark.runners.base import BenchmarkRunner
+from ax.benchmark.runners.botorch_test import ParamBasedTestProblemRunner
 from ax.core.arm import Arm
+from ax.core.data import Data
+from ax.core.experiment import Experiment
 from ax.core.trial import Trial
+from ax.core.types import TParameterization
 from ax.utils.common.testutils import TestCase
+from pyre_extensions import assert_is_instance
 
 
-class JenattonMetricTest(TestCase):
+class JenattonTest(TestCase):
     def test_jenatton_test_function(self) -> None:
+        benchmark_problem = get_jenatton_benchmark_problem()
+
         rand_params = {f"x{i}": random() for i in range(4, 8)}
         rand_params["r8"] = random()
         rand_params["r9"] = random()
+
+        cases: list[tuple[TParameterization, float]] = []
+
         for x3 in (0, 1):
-            self.assertAlmostEqual(
-                jenatton_test_function(
-                    x1=0,
-                    x2=0,
-                    x3=x3,
-                    **{**rand_params, "x4": 2.0, "r8": 0.05},
+            # list of (param dict, expected value)
+            cases.append(
+                (
+                    {
+                        "x1": 0,
+                        "x2": 0,
+                        "x3": x3,
+                        **{**rand_params, "x4": 2.0, "r8": 0.05},
+                    },
+                    4.15,
                 ),
-                4.15,
             )
-            self.assertAlmostEqual(
-                jenatton_test_function(
-                    x1=0,
-                    x2=1,
-                    x3=x3,
-                    **{**rand_params, "x5": 2.0, "r8": 0.05},
-                ),
-                4.25,
+            cases.append(
+                (
+                    {
+                        "x1": 0,
+                        "x2": 1,
+                        "x3": x3,
+                        **{**rand_params, "x5": 2.0, "r8": 0.05},
+                    },
+                    4.25,
+                )
             )
+
         for x2 in (0, 1):
+            cases.append(
+                (
+                    {
+                        "x1": 1,
+                        "x2": x2,
+                        "x3": 0,
+                        **{**rand_params, "x6": 2.0, "r9": 0.05},
+                    },
+                    4.35,
+                )
+            )
+            cases.append(
+                (
+                    {
+                        "x1": 1,
+                        "x2": x2,
+                        "x3": 1,
+                        **{**rand_params, "x7": 2.0, "r9": 0.05},
+                    },
+                    4.45,
+                )
+            )
+
+        for params, value in cases:
+            arm = Arm(parameters=params)
             self.assertAlmostEqual(
-                jenatton_test_function(
-                    x1=1,
-                    x2=x2,
-                    x3=0,
-                    **{**rand_params, "x6": 2.0, "r9": 0.05},
-                ),
-                4.35,
+                # pyre-fixme: Incompatible parameter type [6]: In call
+                # `jenatton_test_function`, for 1st positional argument,
+                # expected `Optional[float]` but got `Union[None, bool, float,
+                # int, str]`.
+                jenatton_test_function(**params),
+                value,
             )
             self.assertAlmostEqual(
-                jenatton_test_function(
-                    x1=1,
-                    x2=x2,
-                    x3=1,
-                    **{**rand_params, "x7": 2.0, "r9": 0.05},
-                ),
-                4.45,
+                assert_is_instance(benchmark_problem.runner, BenchmarkRunner)
+                .get_Y_true(arm)
+                .item(),
+                value,
+                places=6,
             )
 
-    def test_init(self) -> None:
-        metric = JenattonMetric()
-        self.assertEqual(metric.name, "jenatton")
+    def test_create_problem(self) -> None:
+        problem = get_jenatton_benchmark_problem()
+        objective = problem.optimization_config.objective
+        metric = objective.metric
+
+        self.assertEqual(metric.name, "Jenatton")
+        self.assertTrue(objective.minimize)
         self.assertTrue(metric.lower_is_better)
-        self.assertEqual(metric.noise_std, 0.0)
-        self.assertFalse(metric.observe_noise_sd)
-        metric = JenattonMetric(name="nottanej", noise_std=0.1, observe_noise_sd=True)
-        self.assertEqual(metric.name, "nottanej")
+        self.assertEqual(
+            assert_is_instance(
+                problem.runner, ParamBasedTestProblemRunner
+            ).test_problem.noise_std,
+            0.0,
+        )
+        self.assertTrue(problem.is_noiseless)
+        # TODO: make every problem's metrics be BenchmarkMetrics
+        self.assertFalse(assert_is_instance(metric, BenchmarkMetric).observe_noise_sd)
+
+        problem = get_jenatton_benchmark_problem(
+            num_trials=10, noise_std=0.1, observe_noise_sd=True
+        )
+        objective = problem.optimization_config.objective
+        metric = objective.metric
         self.assertTrue(metric.lower_is_better)
-        self.assertEqual(metric.noise_std, 0.1)
-        self.assertTrue(metric.observe_noise_sd)
+        self.assertEqual(
+            assert_is_instance(
+                problem.runner, ParamBasedTestProblemRunner
+            ).test_problem.noise_std,
+            0.1,
+        )
+        self.assertFalse(problem.is_noiseless)
+        self.assertTrue(assert_is_instance(metric, BenchmarkMetric).observe_noise_sd)
 
     def test_fetch_trial_data(self) -> None:
-        arm = mock.Mock(spec=Arm)
-        arm.parameters = {"x1": 0, "x2": 1, "x5": 2.0, "r8": 0.05}
-        trial = mock.Mock(spec=Trial)
-        trial.arms_by_name = {"0_0": arm}
-        trial.index = 0
-
-        metric = JenattonMetric()
-        df = metric.fetch_trial_data(trial=trial).value.df  # pyre-ignore [16]
+        problem = get_jenatton_benchmark_problem()
+        arm = Arm(parameters={"x1": 0, "x2": 1, "x5": 2.0, "r8": 0.05}, name="0_0")
+
+        experiment = Experiment(
+            search_space=problem.search_space,
+            name="Jenatton",
+            optimization_config=problem.optimization_config,
+        )
+
+        trial = Trial(experiment=experiment)
+        trial.add_arm(arm)
+        metadata = problem.runner.run(trial=trial)
+        trial.update_run_metadata(metadata)
+
+        expected_metadata = {
+            "Ys": {"0_0": [4.25]},
+            "Ystds": {"0_0": [0.0]},
+            "outcome_names": ["Jenatton"],
+            "Ys_true": {"0_0": [4.25]},
+        }
+        self.assertEqual(metadata, expected_metadata)
+
+        metric = problem.optimization_config.objective.metric
+
+        df = assert_is_instance(metric.fetch_trial_data(trial=trial).value, Data).df
         self.assertEqual(len(df), 1)
         res_dict = df.iloc[0].to_dict()
         self.assertEqual(res_dict["arm_name"], "0_0")
-        self.assertEqual(res_dict["metric_name"], "jenatton")
+        self.assertEqual(res_dict["metric_name"], "Jenatton")
         self.assertEqual(res_dict["mean"], 4.25)
         self.assertTrue(math.isnan(res_dict["sem"]))
         self.assertEqual(res_dict["trial_index"], 0)
 
-        metric = JenattonMetric(name="nottanej", noise_std=0.1, observe_noise_sd=True)
-        df = metric.fetch_trial_data(trial=trial).value.df  # pyre-ignore [16]
+        problem = get_jenatton_benchmark_problem(noise_std=0.1, observe_noise_sd=True)
+        experiment = Experiment(
+            search_space=problem.search_space,
+            name="Jenatton",
+            optimization_config=problem.optimization_config,
+        )
+
+        trial = Trial(experiment=experiment)
+        trial.add_arm(arm)
+        metadata = problem.runner.run(trial=trial)
+        trial.update_run_metadata(metadata)
+
+        metric = problem.optimization_config.objective.metric
+        df = assert_is_instance(metric.fetch_trial_data(trial=trial).value, Data).df
         self.assertEqual(len(df), 1)
         res_dict = df.iloc[0].to_dict()
         self.assertEqual(res_dict["arm_name"], "0_0")
-        self.assertEqual(res_dict["metric_name"], "nottanej")
         self.assertNotEqual(res_dict["mean"], 4.25)
-        self.assertEqual(res_dict["sem"], 0.1)
+        self.assertAlmostEqual(res_dict["sem"], 0.1)
         self.assertEqual(res_dict["trial_index"], 0)
 
     def test_make_ground_truth_metric(self) -> None:
-        metric = JenattonMetric()
-        gt_metric = metric.make_ground_truth_metric()
-        self.assertIsInstance(gt_metric, JenattonMetric)
-        self.assertEqual(gt_metric.noise_std, 0.0)
-        self.assertFalse(gt_metric.observe_noise_sd)
-        metric = JenattonMetric(noise_std=0.1, observe_noise_sd=True)
+        problem = get_jenatton_benchmark_problem()
+
+        arm = Arm(parameters={"x1": 0, "x2": 1, "x5": 2.0, "r8": 0.05}, name="0_0")
+
+        experiment = Experiment(
+            search_space=problem.search_space,
+            name="Jenatton",
+            optimization_config=problem.optimization_config,
+        )
+
+        trial = Trial(experiment=experiment)
+        trial.add_arm(arm)
+        problem.runner.run(trial=trial)
+        metadata = problem.runner.run(trial=trial)
+        trial.update_run_metadata(metadata)
+
+        metric = assert_is_instance(
+            problem.optimization_config.objective.metric, BenchmarkMetric
+        )
         gt_metric = metric.make_ground_truth_metric()
-        self.assertIsInstance(gt_metric, JenattonMetric)
-        self.assertEqual(gt_metric.noise_std, 0.0)
-        self.assertFalse(gt_metric.observe_noise_sd)
+        self.assertIsInstance(gt_metric, GroundTruthBenchmarkMetric)
+        runner = assert_is_instance(problem.runner, ParamBasedTestProblemRunner)
+        self.assertEqual(runner.test_problem.noise_std, 0.0)
+        self.assertFalse(
+            assert_is_instance(gt_metric, BenchmarkMetric).observe_noise_sd
+        )
+
+        self.assertIsInstance(metric, BenchmarkMetric)
+        self.assertNotIsInstance(metric, GroundTruthBenchmarkMetric)
+        self.assertEqual(runner.test_problem.noise_std, 0.0)
+        self.assertFalse(metric.observe_noise_sd)
diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py
index 3d0ae2eeda3..366c5611122 100644
--- a/ax/benchmark/tests/test_benchmark.py
+++ b/ax/benchmark/tests/test_benchmark.py
@@ -278,16 +278,22 @@ def test_create_benchmark_experiment(self) -> None:
 
     def test_replication_sobol_synthetic(self) -> None:
         method = get_sobol_benchmark_method()
-        problem = get_single_objective_benchmark_problem()
-        res = benchmark_replication(problem=problem, method=method, seed=0)
+        problems = [
+            get_single_objective_benchmark_problem(),
+            get_problem("jenatton", num_trials=6),
+        ]
+        for problem in problems:
+            res = benchmark_replication(problem=problem, method=method, seed=0)
 
-        self.assertEqual(
-            min(problem.num_trials, not_none(method.scheduler_options.total_trials)),
-            len(not_none(res.experiment).trials),
-        )
+            self.assertEqual(
+                min(
+                    problem.num_trials, not_none(method.scheduler_options.total_trials)
+                ),
+                len(not_none(res.experiment).trials),
+            )
 
-        self.assertTrue(np.isfinite(res.score_trace).all())
-        self.assertTrue(np.all(res.score_trace <= 100))
+            self.assertTrue(np.isfinite(res.score_trace).all())
+            self.assertTrue(np.all(res.score_trace <= 100))
 
     def test_replication_sobol_surrogate(self) -> None:
         method = get_sobol_benchmark_method()
diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py
index af8fd720538..12d2ebde923 100644
--- a/ax/storage/json_store/registry.py
+++ b/ax/storage/json_store/registry.py
@@ -17,7 +17,6 @@
 )
 from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
 from ax.benchmark.metrics.benchmark import BenchmarkMetric, GroundTruthBenchmarkMetric
-from ax.benchmark.metrics.jenatton import JenattonMetric
 from ax.benchmark.problems.hpo.pytorch_cnn import PyTorchCNNMetric
 from ax.benchmark.problems.hpo.torchvision import (
     PyTorchCNNTorchvisionBenchmarkProblem,
@@ -213,7 +212,6 @@
     Hartmann6Metric: metric_to_dict,
     ImprovementGlobalStoppingStrategy: improvement_global_stopping_strategy_to_dict,
     Interval: botorch_component_to_dict,
-    JenattonMetric: metric_to_dict,
     L2NormMetric: metric_to_dict,
     LogNormalPrior: botorch_component_to_dict,
     MapData: map_data_to_dict,
@@ -337,7 +335,6 @@
     "HierarchicalSearchSpace": HierarchicalSearchSpace,
     "ImprovementGlobalStoppingStrategy": ImprovementGlobalStoppingStrategy,
     "Interval": Interval,
-    "JenattonMetric": JenattonMetric,
     "LifecycleStage": LifecycleStage,
     "ListSurrogate": Surrogate,  # For backwards compatibility
     "L2NormMetric": L2NormMetric,
diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py
index 02e84bc9326..b53139ea3e5 100644
--- a/ax/storage/json_store/tests/test_json_store.py
+++ b/ax/storage/json_store/tests/test_json_store.py
@@ -13,7 +13,7 @@
 
 import numpy as np
 import torch
-from ax.benchmark.metrics.jenatton import JenattonMetric
+from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_benchmark_problem
 from ax.core.metric import Metric
 from ax.core.objective import Objective
 from ax.core.runner import Runner
@@ -192,7 +192,6 @@
     ("HierarchicalSearchSpace", get_hierarchical_search_space),
     ("ImprovementGlobalStoppingStrategy", get_improvement_global_stopping_strategy),
     ("Interval", get_interval),
-    ("JenattonMetric", JenattonMetric),
     ("MapData", get_map_data),
     ("MapData", get_map_data),
     ("MapKeyInfo", get_map_key_info),
@@ -209,6 +208,7 @@
     ("OrderConstraint", get_order_constraint),
     ("OutcomeConstraint", get_outcome_constraint),
     ("Path", get_pathlib_path),
+    ("Jenatton", get_jenatton_benchmark_problem),
     ("PercentileEarlyStoppingStrategy", get_percentile_early_stopping_strategy),
     (
         "PercentileEarlyStoppingStrategy",
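
Usage note (editorial, not part of the patch): a minimal sketch of how the
migrated problem can be exercised after this change, based only on the APIs
shown in the diff and its tests (`get_jenatton_benchmark_problem`, the new
`noise_std` / `observe_noise_sd` arguments, `ParamBasedTestProblemRunner`,
`get_Y_true`, and `jenatton_test_function`); anything beyond those names is
illustrative.

    from ax.benchmark.problems.synthetic.hss.jenatton import (
        get_jenatton_benchmark_problem,
    )
    from ax.benchmark.runners.botorch_test import ParamBasedTestProblemRunner
    from ax.core.arm import Arm

    # Noiseless problem, as before this change.
    problem = get_jenatton_benchmark_problem(num_trials=10)

    # The runner is now a ParamBasedTestProblemRunner wrapping the Jenatton
    # ParamBasedTestProblem, so noiseless values come from get_Y_true.
    runner = problem.runner
    assert isinstance(runner, ParamBasedTestProblemRunner)
    arm = Arm(parameters={"x1": 0, "x2": 1, "x5": 2.0, "r8": 0.05})
    y_true = runner.get_Y_true(arm).item()  # 4.25, matching the test above

    # Noisy variant with observed noise levels, newly supported by the
    # noise_std / observe_noise_sd arguments.
    noisy_problem = get_jenatton_benchmark_problem(
        num_trials=10, noise_std=0.1, observe_noise_sd=True
    )

The same arm evaluated directly through
`jenatton_test_function(x1=0, x2=1, x5=2.0, r8=0.05)` also gives 4.25, which
is what the rewritten `test_jenatton_test_function` case checks.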