From f5611064dc1ca258c27bdc678aa32f6c3053617c Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Tue, 24 Sep 2024 09:33:19 -0700 Subject: [PATCH] Parameterization type fixes; use parameters instad of arms where appropriate Summary: Sometimes benchmarking code creates Arms from parameters just so they can be passed to a function that expects an Arm but only uses its parameters. This is silly. It's better to just have the function expect parameters. Also updated some method signatures to use `Mapping` to indicate that they do not mutate the parameterization, which unfortunately creates the need for Pyre-fixmes if they are passed to a function or class such as `Arm` that does not annotate its arguments as immutable. Differential Revision: D63327381 --- .../problems/synthetic/hss/jenatton.py | 5 +++-- ax/benchmark/runners/base.py | 6 +++--- ax/benchmark/runners/botorch_test.py | 20 +++++++++---------- ax/benchmark/runners/surrogate.py | 8 +++++--- .../runners/test_botorch_test_problem.py | 12 +++++------ ax/benchmark/tests/test_benchmark_problem.py | 6 +++--- 6 files changed, 29 insertions(+), 28 deletions(-) diff --git a/ax/benchmark/problems/synthetic/hss/jenatton.py b/ax/benchmark/problems/synthetic/hss/jenatton.py index 1e3c8d5e5ca..30272a7babe 100644 --- a/ax/benchmark/problems/synthetic/hss/jenatton.py +++ b/ax/benchmark/problems/synthetic/hss/jenatton.py @@ -5,6 +5,7 @@ # pyre-strict +from collections.abc import Mapping from dataclasses import dataclass from typing import Optional @@ -19,7 +20,6 @@ from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.search_space import HierarchicalSearchSpace -from ax.core.types import TParameterization from pyre_extensions import none_throws @@ -60,7 +60,8 @@ class Jenatton(ParamBasedTestProblem): optimal_value: float = 0.1 _is_constrained: bool = False - def evaluate_true(self, params: TParameterization) -> torch.Tensor: + # pyre-fixme[14]: Inconsistent override + def evaluate_true(self, params: Mapping[str, float | int | None]) -> torch.Tensor: # pyre-fixme: Incompatible parameter type [6]: In call # `jenatton_test_function`, for 1st positional argument, expected # `Optional[float]` but got `Union[None, bool, float, int, str]`. diff --git a/ax/benchmark/runners/base.py b/ax/benchmark/runners/base.py index dfa901aa95b..0e5a335221a 100644 --- a/ax/benchmark/runners/base.py +++ b/ax/benchmark/runners/base.py @@ -59,7 +59,7 @@ def __init__(self, search_space_digest: SearchSpaceDigest | None = None) -> None else: self.target_fidelity_and_task = {} - def get_Y_true(self, arm: Arm) -> Tensor: + def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor: """ Return the ground truth values for a given arm. @@ -79,7 +79,7 @@ def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> ndarray: at the true utility function (which would be unobserved in reality). """ params = {**parameters, **self.target_fidelity_and_task} - return self.get_Y_true(arm=Arm(parameters=params)).numpy() + return self.get_Y_true(params=params).numpy() @abstractmethod def get_noise_stds(self) -> Union[None, float, dict[str, float]]: @@ -134,7 +134,7 @@ def run(self, trial: BaseTrial) -> dict[str, Any]: for arm in trial.arms: # Case where we do have a ground truth - Y_true = self.get_Y_true(arm) + Y_true = self.get_Y_true(arm.parameters) if noise_stds is None: # No noise, so just return the true outcome. Ystds[arm.name] = [0.0] * len(Y_true) diff --git a/ax/benchmark/runners/botorch_test.py b/ax/benchmark/runners/botorch_test.py index f45a1524eeb..67b3436ebef 100644 --- a/ax/benchmark/runners/botorch_test.py +++ b/ax/benchmark/runners/botorch_test.py @@ -7,6 +7,7 @@ import importlib from abc import ABC, abstractmethod +from collections.abc import Mapping from dataclasses import dataclass from typing import Any, Optional, Union @@ -14,7 +15,7 @@ from ax.benchmark.runners.base import BenchmarkRunner from ax.core.arm import Arm from ax.core.search_space import SearchSpaceDigest -from ax.core.types import TParameterization +from ax.core.types import TParamValue from ax.utils.common.base import Base from ax.utils.common.equality import equality_typechecker from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry @@ -41,9 +42,9 @@ class ParamBasedTestProblem(ABC): negate: bool = False @abstractmethod - def evaluate_true(self, params: TParameterization) -> Tensor: ... + def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor: ... - def evaluate_slack_true(self, params: TParameterization) -> Tensor: + def evaluate_slack_true(self, params: Mapping[str, TParamValue]) -> Tensor: raise NotImplementedError( f"{self.__class__.__name__} does not support constraints." ) @@ -243,7 +244,7 @@ def __init__( self.test_problem, ConstrainedBaseTestProblem ) - def get_Y_true(self, arm: Arm) -> Tensor: + def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor: """ Convert the arm to a tensor and evaluate it on the base test problem. @@ -252,7 +253,7 @@ def get_Y_true(self, arm: Arm) -> Tensor: `modified_bounds` in `BotorchTestProblemRunner.__init__` for details. Args: - arm: Arm to evaluate. It will be converted to a + params: Parameterization to evaluate. It will be converted to a `batch_shape x d`-dim tensor of point(s) at which to evaluate the test problem. @@ -260,10 +261,7 @@ def get_Y_true(self, arm: Arm) -> Tensor: A `batch_shape x m`-dim tensor of ground truth (noiseless) evaluations. """ X = torch.tensor( - [ - value - for _key, value in [*arm.parameters.items()][: self.test_problem.dim] - ], + [value for _key, value in [*params.items()][: self.test_problem.dim]], dtype=torch.double, ) @@ -322,13 +320,13 @@ def __init__( ) self.test_problem: ParamBasedTestProblem = self.test_problem - def get_Y_true(self, arm: Arm) -> Tensor: + def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor: """Evaluates the test problem. Returns: A `batch_shape x m`-dim tensor of ground truth (noiseless) evaluations. """ - Y_true = self.test_problem.evaluate_true(arm.parameters).view(-1) + Y_true = self.test_problem.evaluate_true(params).view(-1) # `ParamBasedTestProblem.evaluate_true()` does not negate the outcome if self.test_problem.negate: Y_true = -Y_true diff --git a/ax/benchmark/runners/surrogate.py b/ax/benchmark/runners/surrogate.py index 3954d60e17c..fec322303c7 100644 --- a/ax/benchmark/runners/surrogate.py +++ b/ax/benchmark/runners/surrogate.py @@ -6,11 +6,11 @@ # pyre-strict import warnings +from collections.abc import Mapping from typing import Any, Callable, Optional, Union import torch from ax.benchmark.runners.base import BenchmarkRunner -from ax.core.arm import Arm from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.observation import ObservationFeatures from ax.core.search_space import SearchSpace, SearchSpaceDigest @@ -95,11 +95,13 @@ def datasets(self) -> list[SupervisedDataset]: def get_noise_stds(self) -> Union[None, float, dict[str, float]]: return self.noise_stds - def get_Y_true(self, arm: Arm) -> Tensor: + # pyre-fixme[14]: Inconsistent override + def get_Y_true(self, params: Mapping[str, float | int]) -> Tensor: # We're ignoring the uncertainty predictions of the surrogate model here and # use the mean predictions as the outcomes (before potentially adding noise) means, _ = self.surrogate.predict( - observation_features=[ObservationFeatures(arm.parameters)] + # pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict + observation_features=[ObservationFeatures(params)] ) means = [means[name][0] for name in self.outcome_names] return torch.tensor( diff --git a/ax/benchmark/tests/runners/test_botorch_test_problem.py b/ax/benchmark/tests/runners/test_botorch_test_problem.py index cd858395a2c..6b3b5b9fe8f 100644 --- a/ax/benchmark/tests/runners/test_botorch_test_problem.py +++ b/ax/benchmark/tests/runners/test_botorch_test_problem.py @@ -124,11 +124,8 @@ def test_synthetic_runner(self) -> None: with self.subTest(f"test `get_Y_true()`, {test_description}"): X = torch.rand(1, 6, dtype=torch.double) - arm = Arm( - name="0_0", - parameters={f"x{i}": x.item() for i, x in enumerate(X.unbind(-1))}, - ) - Y = runner.get_Y_true(arm=arm) + params = {f"x{i}": x.item() for i, x in enumerate(X.unbind(-1))} + Y = runner.get_Y_true(params=params) if modified_bounds is not None: X_tf = normalize( X, torch.tensor(modified_bounds, dtype=torch.double).T @@ -152,11 +149,14 @@ def test_synthetic_runner(self) -> None: torch.Size([2]), X.pow(2).sum().item(), dtype=torch.double ) self.assertTrue(torch.allclose(Y, expected_Y)) - oracle = runner.evaluate_oracle(parameters=arm.parameters) + oracle = runner.evaluate_oracle(parameters=params) self.assertTrue(np.equal(Y.numpy(), oracle).all()) with self.subTest(f"test `run()`, {test_description}"): trial = Mock(spec=Trial) + # pyre-fixme[6]: Incomptabile parameter type: params is a + # mutable subtype of the type expected by `Arm`. + arm = Arm(name="0_0", parameters=params) trial.arms = [arm] trial.arm = arm trial.index = 0 diff --git a/ax/benchmark/tests/test_benchmark_problem.py b/ax/benchmark/tests/test_benchmark_problem.py index 9e1218a8070..17756fafab7 100644 --- a/ax/benchmark/tests/test_benchmark_problem.py +++ b/ax/benchmark/tests/test_benchmark_problem.py @@ -76,7 +76,7 @@ def _test_multi_fidelity_or_multi_task(self, fidelity_or_task: str) -> None: search_space=SearchSpace(parameters), num_trials=3, ) - arm = Arm(parameters={"x0": 1.0, "x1": 0.0, "x2": 0.0}) + params = {"x0": 1.0, "x1": 0.0, "x2": 0.0} at_target = assert_is_instance( Branin() .evaluate_true(torch.tensor([1.0, 0.0], dtype=torch.double).unsqueeze(0)) @@ -84,7 +84,7 @@ def _test_multi_fidelity_or_multi_task(self, fidelity_or_task: str) -> None: float, ) self.assertAlmostEqual( - problem.runner.evaluate_oracle(parameters=arm.parameters)[0], + problem.runner.evaluate_oracle(parameters=params)[0], at_target, ) # first term: (-(b - 0.1) * (1 - x3) + c - r)^2 @@ -93,7 +93,7 @@ def _test_multi_fidelity_or_multi_task(self, fidelity_or_task: str) -> None: t = -5.1 / (4 * math.pi**2) + 5 / math.pi - 6 expected_change = (t + 0.1) ** 2 - t**2 self.assertAlmostEqual( - problem.runner.get_Y_true(arm=arm).item(), + problem.runner.get_Y_true(params=params).item(), at_target + expected_change, )