Introduce ParamBasedTestProblem for benchmarking (facebook#2675)
Summary:
Pull Request resolved: facebook#2675

Context:

In a future refactor that will enable more flexible and powerful best-point functionality, every `BenchmarkProblem`'s runner will be able to produce an "oracle" value (possibly the ground truth) for any arm, in-sample or not, with a function like `BenchmarkRunner.evaluate_oracle(arm=arm)`, with the problem handling computation and the runner formatting results. However, the current `BenchmarkRunner` and `BenchmarkMetric` setup doesn't cover every benchmark. Consolidating on `BenchmarkRunner` and `BenchmarkMetric` will enable the refactor, make it easier to universalize functionality such as handling of constraints, noise, and inference regret, and allow some lines of code to be deleted from the more custom problems.

Current `BenchmarkRunner`s only handle problems that can consume tensor-valued arguments: BoTorch synthetic problems and surrogate problems. This is a poor fit for problems like Jenatton, which has a hierarchical search space in which some parameters may not be passed. Because Ax always passes parameters but only sometimes represents them as tensors, a `TParameterization` is a more natural abstraction for handling parameters than a tensor.
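
To make the contrast concrete, here is a minimal sketch of a param-based problem over a hierarchical search space, written against the `ParamBasedTestProblem` interface introduced in this diff. The class `ToyHierarchicalProblem` and its parameter names are hypothetical, purely for illustration:

```python
from dataclasses import dataclass

import torch
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.core.types import TParameterization
from torch import Tensor


@dataclass(kw_only=True)
class ToyHierarchicalProblem(ParamBasedTestProblem):
    """Toy problem in which `x1` exists only on the "left" branch."""

    num_objectives: int = 1
    optimal_value: float = 0.0

    def evaluate_true(self, params: TParameterization) -> Tensor:
        if params.get("branch") == "left":
            # `x1` may be entirely absent on the other branch -- awkward to
            # encode as a fixed-width tensor, trivial as a parameterization.
            return torch.tensor(float(params["x1"]) ** 2)
        return torch.tensor((float(params["x2"]) - 0.5) ** 2 + 0.1)
```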

This PR:
- Introduces `ParamBasedTestProblem`, which is like a BoTorch synthetic test problem but consumes a `TParameterization` rather than a tensor
- Adds `ParamBasedTestProblemRunner`, which shares a base class (`SyntheticProblemRunner`) and most functionality with `BotorchTestProblemRunner`, so it is a `BenchmarkRunner` and supports both observed and unobserved noise; see the usage sketch below
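
As a usage sketch (hypothetical, reusing the `ToyHierarchicalProblem` above; the runner API is as defined in this diff):

```python
from ax.benchmark.runners.botorch_test import ParamBasedTestProblemRunner
from ax.core.arm import Arm

# The runner wraps a problem class plus its kwargs, mirroring how
# `BotorchTestProblemRunner` wraps a `SyntheticTestFunction` class.
runner = ParamBasedTestProblemRunner(
    test_problem_class=ToyHierarchicalProblem,
    test_problem_kwargs={},
    outcome_names=["objective"],
)

# Ground-truth evaluation goes straight from the parameterization to the
# problem -- no tensor conversion or bound normalization in between.
arm = Arm(name="0_0", parameters={"branch": "left", "x1": 0.3})
y_true = runner.get_Y_true(arm)  # tensor([0.0900])
```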

Differential Revision: D60996475
esantorella authored and facebook-github-bot committed Aug 20, 2024
1 parent 631b89c commit 5fefcbd
Showing 7 changed files with 341 additions and 118 deletions.
17 changes: 11 additions & 6 deletions ax/benchmark/runners/base.py
@@ -5,13 +5,15 @@

 # pyre-strict

-from abc import ABC, abstractmethod, abstractproperty
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from math import sqrt
 from typing import Any, Union

 import torch
 from ax.core.arm import Arm
-from ax.core.base_trial import BaseTrial
+
+from ax.core.base_trial import BaseTrial, TrialStatus
 from ax.core.batch_trial import BatchTrial
 from ax.core.runner import Runner
 from ax.core.trial import Trial
@@ -39,10 +41,7 @@ class BenchmarkRunner(Runner, ABC):
     not over-engineer for that before such a use case arrives.
     """

-    @abstractproperty
-    def outcome_names(self) -> list[str]:
-        """The names of the outcomes of the problem (in the order of the outcomes)."""
-        pass  # pragma: no cover
+    outcome_names: list[str]

     def get_Y_true(self, arm: Arm) -> Tensor:
         """
@@ -132,3 +131,9 @@ def run(self, trial: BaseTrial) -> dict[str, Any]:
             "Ys_true": Ys_true,
         }
         return run_metadata
+
+    # This will need to be updated once asynchronous benchmarks are supported.
+    def poll_trial_status(
+        self, trials: Iterable[BaseTrial]
+    ) -> dict[TrialStatus, set[int]]:
+        return {TrialStatus.COMPLETED: {t.index for t in trials}}
256 changes: 187 additions & 69 deletions ax/benchmark/runners/botorch_test.py
@@ -6,50 +6,85 @@
 # pyre-strict

 import importlib
-from collections.abc import Iterable
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Any, Optional, Union

 import torch
 from ax.benchmark.runners.base import BenchmarkRunner
 from ax.core.arm import Arm
-from ax.core.base_trial import BaseTrial, TrialStatus
+from ax.core.types import TParameterization
 from ax.utils.common.base import Base
 from ax.utils.common.equality import equality_typechecker
 from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
-from ax.utils.common.typeutils import checked_cast
-from botorch.test_functions.base import BaseTestProblem, ConstrainedBaseTestProblem
-from botorch.test_functions.multi_objective import MultiObjectiveTestProblem
+from botorch.test_functions.synthetic import (
+    ConstrainedSyntheticTestFunction,
+    SyntheticTestFunction,
+)
 from botorch.utils.transforms import normalize, unnormalize
+from pyre_extensions import assert_is_instance
 from torch import Tensor


-class BotorchTestProblemRunner(BenchmarkRunner):
-    """A Runner for evaluating Botorch BaseTestProblems.
+@dataclass(kw_only=True)
+class ParamBasedTestProblem(ABC):
+    """
+    Similar to a BoTorch test problem, but evaluated using an Ax
+    TParameterization rather than a tensor.
+    """
+
+    num_objectives: int
+    optimal_value: float
+    # Constraints could easily be supported similar to BoTorch test problems,
+    # but haven't been hooked up.
+    _is_constrained: bool = False
+    constraint_noise_std: Optional[Union[float, list[float]]] = None
+    noise_std: Optional[Union[float, list[float]]] = None
+    negate: bool = False
+
+    @abstractmethod
+    def evaluate_true(self, params: TParameterization) -> Tensor: ...
+
+    def evaluate_slack_true(self, params: TParameterization) -> Tensor:
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not support constraints."
+        )
+
+    # pyre-fixme: Missing parameter annotation [2]: Parameter `other` must have
+    # a type other than `Any`.
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, type(self)):
+            return False
+        return self.__class__.__name__ == other.__class__.__name__


-    Given a trial the Runner will evaluate the BaseTestProblem.forward method for each
-    arm in the trial, as well as return some metadata about the underlying Botorch
-    problem such as the noise_std. We compute the full result on the Runner (as opposed
-    to the Metric as is typical in synthetic test problems) because the BoTorch problem
-    computes all metrics in one stacked tensor in the MOO case, and we wish to avoid
-    recomputation per metric.
+class SyntheticProblemRunner(BenchmarkRunner, ABC):
+    """A Runner for evaluating synthetic problems, either BoTorch
+    `SyntheticTestFunction`s or Ax benchmarking `ParamBasedTestProblem`s.
+    Given a trial, the Runner will evaluate the problem noiselessly for each
+    arm in the trial, as well as return some metadata about the underlying
+    problem such as the noise_std.
     """

-    test_problem: BaseTestProblem
+    test_problem: Union[SyntheticTestFunction, ParamBasedTestProblem]
     _is_constrained: bool
-    _test_problem_class: type[BaseTestProblem]
+    _test_problem_class: type[Union[SyntheticTestFunction, ParamBasedTestProblem]]
     _test_problem_kwargs: Optional[dict[str, Any]]

     def __init__(
         self,
-        test_problem_class: type[BaseTestProblem],
+        *,
+        test_problem_class: type[Union[SyntheticTestFunction, ParamBasedTestProblem]],
         test_problem_kwargs: dict[str, Any],
         outcome_names: list[str],
         modified_bounds: Optional[list[tuple[float, float]]] = None,
     ) -> None:
         """Initialize the test problem runner.
         Args:
-            test_problem_class: The BoTorch test problem class.
+            test_problem_class: A BoTorch `SyntheticTestFunction` class or Ax
+                `ParamBasedTestProblem` class.
             test_problem_kwargs: The keyword arguments used for initializing the
                 test problem.
             outcome_names: The names of the outcomes returned by the problem.
@@ -63,28 +98,27 @@ def __init__(
             If modified bounds are not provided, the test problem will be
             evaluated using the raw parameter values.
         """
-
         self._test_problem_class = test_problem_class
         self._test_problem_kwargs = test_problem_kwargs
-
-        # pyre-fixme [45]: Invalid class instantiation
-        self.test_problem = test_problem_class(**test_problem_kwargs).to(
-            dtype=torch.double
+        self.test_problem = (
+            # pyre-fixme: Invalid class instantiation [45]: Cannot instantiate
+            # abstract class with abstract method `evaluate_true`.
+            test_problem_class(**test_problem_kwargs)
         )
+        if isinstance(self.test_problem, SyntheticTestFunction):
+            self.test_problem = self.test_problem.to(dtype=torch.double)
+        # A `ConstrainedSyntheticTestFunction` is a type of `SyntheticTestFunction`; a
+        # `ParamBasedTestProblem` is never constrained.
         self._is_constrained: bool = isinstance(
-            self.test_problem, ConstrainedBaseTestProblem
+            self.test_problem, ConstrainedSyntheticTestFunction
         )
-        self._is_moo: bool = isinstance(self.test_problem, MultiObjectiveTestProblem)
-        self._outcome_names = outcome_names
+        self._is_moo: bool = self.test_problem.num_objectives > 1
+        self.outcome_names = outcome_names
         self._modified_bounds = modified_bounds

-    @property
-    def outcome_names(self) -> list[str]:
-        return self._outcome_names
-
     @equality_typechecker
     def __eq__(self, other: Base) -> bool:
-        if not isinstance(other, BotorchTestProblemRunner):
+        if not isinstance(other, type(self)):
             return False

         return (
@@ -129,12 +163,95 @@ def get_noise_stds(self) -> Union[None, float, dict[str, float]]:

         return noise_std_dict

+    @classmethod
+    # pyre-fixme [2]: Parameter `obj` must have a type other than `Any`
+    def serialize_init_args(cls, obj: Any) -> dict[str, Any]:
+        """Serialize the properties needed to initialize the runner.
+        Used for storage.
+        """
+        runner = assert_is_instance(obj, cls)
+
+        return {
+            "test_problem_module": runner._test_problem_class.__module__,
+            "test_problem_class_name": runner._test_problem_class.__name__,
+            "test_problem_kwargs": runner._test_problem_kwargs,
+            "outcome_names": runner.outcome_names,
+            "modified_bounds": runner._modified_bounds,
+        }
+
+    @classmethod
+    def deserialize_init_args(
+        cls,
+        args: dict[str, Any],
+        decoder_registry: Optional[TDecoderRegistry] = None,
+        class_decoder_registry: Optional[TClassDecoderRegistry] = None,
+    ) -> dict[str, Any]:
+        """Given a dictionary, deserialize the properties needed to initialize the
+        runner. Used for storage.
+        """
+
+        module = importlib.import_module(args["test_problem_module"])
+
+        return {
+            "test_problem_class": getattr(module, args["test_problem_class_name"]),
+            "test_problem_kwargs": args["test_problem_kwargs"],
+            "outcome_names": args["outcome_names"],
+            "modified_bounds": args["modified_bounds"],
+        }
+
+
+class BotorchTestProblemRunner(SyntheticProblemRunner):
+    """
+    A `SyntheticProblemRunner` for BoTorch `SyntheticTestFunction`s.
+    Args:
+        test_problem_class: A BoTorch `SyntheticTestFunction` class.
+        test_problem_kwargs: The keyword arguments used for initializing the
+            test problem.
+        outcome_names: The names of the outcomes returned by the problem.
+        modified_bounds: The bounds that are used by the Ax search space
+            while optimizing the problem. If different from the bounds of the
+            test problem, we project the parameters into the test problem
+            bounds before evaluating the test problem.
+            For example, if the test problem is defined on [0, 1] but the Ax
+            search space is integers in [0, 10], an Ax parameter value of
+            5 will correspond to 0.5 while evaluating the test problem.
+            If modified bounds are not provided, the test problem will be
+            evaluated using the raw parameter values.
+    """
+
+    def __init__(
+        self,
+        *,
+        test_problem_class: type[SyntheticTestFunction],
+        test_problem_kwargs: dict[str, Any],
+        outcome_names: list[str],
+        modified_bounds: Optional[list[tuple[float, float]]] = None,
+    ) -> None:
+        super().__init__(
+            test_problem_class=test_problem_class,
+            test_problem_kwargs=test_problem_kwargs,
+            outcome_names=outcome_names,
+            modified_bounds=modified_bounds,
+        )
+        self.test_problem: SyntheticTestFunction = self.test_problem.to(
+            dtype=torch.double
+        )
+        self._is_constrained: bool = isinstance(
+            self.test_problem, ConstrainedSyntheticTestFunction
+        )

     def get_Y_true(self, arm: Arm) -> Tensor:
-        """Converts X to original bounds -- only if modified bounds were provided --
-        and evaluates the test problem. See `__init__` docstring for details.
+        """
+        Convert the arm to a tensor and evaluate it on the base test problem.
+        Convert the tensor to original bounds -- only if modified bounds were
+        provided -- and evaluates the test problem. See the docstring for
+        `modified_bounds` in `BotorchTestProblemRunner.__init__` for details.
         Args:
-            X: A `batch_shape x d`-dim tensor of point(s) at which to evaluate the
+            arm: Arm to evaluate. It will be converted to a
+                `batch_shape x d`-dim tensor of point(s) at which to evaluate the
                 test problem.
         Returns:
@@ -157,7 +274,7 @@ def get_Y_true(self, arm: Arm) -> Tensor:
             X = unnormalize(unit_X, self.test_problem.bounds)

         Y_true = self.test_problem.evaluate_true(X).view(-1)
-        # `BaseTestProblem.evaluate_true()` does not negate the outcome
+        # `SyntheticTestFunction.evaluate_true()` does not negate the outcome
         if self.test_problem.negate:
             Y_true = -Y_true

@@ -171,43 +288,44 @@ def get_Y_true(self, arm: Arm) -> Tensor:

         return Y_true

-    def poll_trial_status(
-        self, trials: Iterable[BaseTrial]
-    ) -> dict[TrialStatus, set[int]]:
-        return {TrialStatus.COMPLETED: {t.index for t in trials}}

-    @classmethod
-    # pyre-fixme [2]: Parameter `obj` must have a type other than `Any``
-    def serialize_init_args(cls, obj: Any) -> dict[str, Any]:
-        """Serialize the properties needed to initialize the runner.
-        Used for storage.
-        """
-        runner = checked_cast(BotorchTestProblemRunner, obj)
+class ParamBasedTestProblemRunner(SyntheticProblemRunner):
+    """
+    A `SyntheticProblemRunner` for `ParamBasedTestProblem`s. See
+    `SyntheticProblemRunner` for more information.
+    """

-        return {
-            "test_problem_module": runner._test_problem_class.__module__,
-            "test_problem_class_name": runner._test_problem_class.__name__,
-            "test_problem_kwargs": runner._test_problem_kwargs,
-            "outcome_names": runner._outcome_names,
-            "modified_bounds": runner._modified_bounds,
-        }
+    # This could easily be supported, but hasn't been hooked up
+    _is_constrained: bool = False

-    @classmethod
-    def deserialize_init_args(
-        cls,
-        args: dict[str, Any],
-        decoder_registry: Optional[TDecoderRegistry] = None,
-        class_decoder_registry: Optional[TClassDecoderRegistry] = None,
-    ) -> dict[str, Any]:
-        """Given a dictionary, deserialize the properties needed to initialize the
-        runner. Used for storage.
-        """
+    def __init__(
+        self,
+        *,
+        test_problem_class: type[ParamBasedTestProblem],
+        test_problem_kwargs: dict[str, Any],
+        outcome_names: list[str],
+        modified_bounds: Optional[list[tuple[float, float]]] = None,
+    ) -> None:
+        if modified_bounds is not None:
+            raise NotImplementedError(
+                f"modified_bounds is not supported for {test_problem_class.__name__}"
+            )
+        super().__init__(
+            test_problem_class=test_problem_class,
+            test_problem_kwargs=test_problem_kwargs,
+            outcome_names=outcome_names,
+            modified_bounds=modified_bounds,
+        )
+        self.test_problem: ParamBasedTestProblem = self.test_problem

-        module = importlib.import_module(args["test_problem_module"])
+    def get_Y_true(self, arm: Arm) -> Tensor:
+        """Evaluates the test problem.

-        return {
-            "test_problem_class": getattr(module, args["test_problem_class_name"]),
-            "test_problem_kwargs": args["test_problem_kwargs"],
-            "outcome_names": args["outcome_names"],
-            "modified_bounds": args["modified_bounds"],
-        }
+        Returns:
+            A `batch_shape x m`-dim tensor of ground truth (noiseless) evaluations.
+        """
+        Y_true = self.test_problem.evaluate_true(arm.parameters).view(-1)
+        # `ParamBasedTestProblem.evaluate_true()` does not negate the outcome
+        if self.test_problem.negate:
+            Y_true = -Y_true
+        return Y_true