Enable inference regret #2782

Open · wants to merge 1 commit into base: main
77 changes: 71 additions & 6 deletions ax/benchmark/benchmark.py
@@ -22,14 +22,15 @@
from collections.abc import Iterable
from itertools import product
from logging import Logger
from time import time
from time import monotonic, time

import numpy as np

from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
from ax.core.experiment import Experiment
from ax.core.types import TParameterization
from ax.core.utils import get_model_times
from ax.service.scheduler import Scheduler
from ax.service.utils.best_point_mixin import BestPointMixin
@@ -93,12 +94,23 @@ def benchmark_replication(
method: BenchmarkMethod,
seed: int,
) -> BenchmarkResult:
"""Runs one benchmarking replication (equivalent to one optimization loop).
"""
Run one benchmarking replication (equivalent to one optimization loop).

After each trial, the `method` gets the best parameter(s) found so far, as
evaluated based on empirical data. After all trials are run, the `problem`
gets the oracle values of each "best" parameter; this yields the
``inference_trace``. The cumulative maximum of the oracle value of each
parameterization tested is the ``oracle_trace``.


Args:
problem: The BenchmarkProblem to test against (can be synthetic or real).
method: The BenchmarkMethod to test.
seed: The seed to use for this replication.

Returns:
A ``BenchmarkResult`` object.
"""

experiment = Experiment(
@@ -113,19 +125,70 @@
generation_strategy=method.generation_strategy.clone_reset(),
options=method.scheduler_options,
)
timeout_hours = scheduler.options.timeout_hours

# list of parameters for each trial
best_params_by_trial: list[list[TParameterization]] = []

is_mf_or_mt = len(problem.runner.target_fidelity_and_task) > 0
# Run the optimization loop.
timeout_hours = scheduler.options.timeout_hours
with with_rng_seed(seed=seed):
scheduler.run_n_trials(max_trials=problem.num_trials)
start = monotonic()
for _ in range(problem.num_trials):
next(
scheduler.run_trials_and_yield_results(
max_trials=1, timeout_hours=timeout_hours
)
)
if timeout_hours is not None:
elapsed_hours = (monotonic() - start) / 3600
timeout_hours = timeout_hours - elapsed_hours
if timeout_hours <= 0:
break

if problem.is_moo or is_mf_or_mt:
# Inference trace is not supported for MOO.
# It's also not supported for multi-fidelity or multi-task
# problems, because Ax's best-point functionality doesn't know
# to predict at the target task or fidelity.
continue

best_params = method.get_best_parameters(
experiment=experiment,
optimization_config=problem.optimization_config,
n_points=problem.n_best_points,
)
best_params_by_trial.append(best_params)

# Construct inference trace from best parameters
inference_trace = np.full(problem.num_trials, np.nan)
for trial_index, best_params in enumerate(best_params_by_trial):
if len(best_params) == 0:
inference_trace[trial_index] = np.nan
continue
# Construct an experiment with one BatchTrial
best_params_oracle_experiment = problem.get_oracle_experiment_from_params(
{0: {str(i): p for i, p in enumerate(best_params)}}
)
# Get the optimization trace. It will have only one point.
inference_trace[trial_index] = BestPointMixin._get_trace(
experiment=best_params_oracle_experiment,
optimization_config=problem.optimization_config,
)[0]

oracle_experiment = problem.get_oracle_experiment_from_experiment(
actual_params_oracle_experiment = problem.get_oracle_experiment_from_experiment(
experiment=experiment
)
optimization_trace = np.array(
oracle_trace = np.array(
BestPointMixin._get_trace(
experiment=oracle_experiment,
experiment=actual_params_oracle_experiment,
optimization_config=problem.optimization_config,
)
)
optimization_trace = (
inference_trace if problem.report_inference_value_as_trace else oracle_trace
)

try:
# Catch any errors that may occur during score computation, such as errors
@@ -155,6 +218,8 @@ def benchmark_replication(
name=scheduler.experiment.name,
seed=seed,
experiment=scheduler.experiment,
oracle_trace=oracle_trace,
inference_trace=inference_trace,
optimization_trace=optimization_trace,
score_trace=score_trace,
fit_time=fit_time,
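To make the new trace semantics concrete, here is a minimal NumPy-only sketch of how the two traces differ; the toy `oracle` function and the point lists are hypothetical and not part of this change.

```python
import numpy as np

# Hypothetical ground-truth ("oracle") objective, maximization.
def oracle(x: float) -> float:
    return -((x - 0.3) ** 2)

evaluated = [0.9, 0.5, 0.35, 0.6]    # parameterizations the method actually ran
recommended = [0.9, 0.9, 0.35, 0.35]  # "best" point chosen after each trial from
                                      # empirically observable (noisy) data only

# oracle_trace: cumulative best ground-truth value over the evaluated points.
oracle_trace = np.maximum.accumulate([oracle(x) for x in evaluated])

# inference_trace: ground-truth value of the point recommended after each trial.
# It can lag the oracle trace when noise makes a worse point look best.
inference_trace = np.array([oracle(x) for x in recommended])

# optimization_trace is one of the two, depending on
# problem.report_inference_value_as_trace.
```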
82 changes: 76 additions & 6 deletions ax/benchmark/benchmark_method.py
@@ -5,16 +5,20 @@

# pyre-strict

import logging
from dataclasses import dataclass
from dataclasses import dataclass, field

from ax.core.experiment import Experiment
from ax.core.optimization_config import (
MultiObjectiveOptimizationConfig,
OptimizationConfig,
)
from ax.core.types import TParameterization

from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.service.utils.scheduler_options import SchedulerOptions, TrialType
from ax.utils.common.base import Base
from ax.utils.common.logger import get_logger


logger: logging.Logger = get_logger("BenchmarkMethod")
from pyre_extensions import none_throws


@dataclass(frozen=True)
@@ -36,12 +40,78 @@ class BenchmarkMethod(Base):
`get_benchmark_scheduler_options`.
distribute_replications: Indicates whether the replications should be
run in a distributed manner. Ax itself does not use this attribute.
best_point_kwargs: Arguments passed to `get_pareto_optimal_parameters`
(if multi-objective) or `BestPointMixin._get_best_trial` (if
single-objective). Currently, the only supported argument is
`use_model_predictions`. However, note that if multi-objective,
best-point selection is not currently supported and
`get_pareto_optimal_parameters` will raise a `NotImplementedError`.
"""

name: str
generation_strategy: GenerationStrategy
scheduler_options: SchedulerOptions
distribute_replications: bool = False
best_point_kwargs: dict[str, bool] = field(
default_factory=lambda: {"use_model_predictions": False}
)

def get_best_parameters(
self,
experiment: Experiment,
optimization_config: OptimizationConfig,
n_points: int,
) -> list[TParameterization]:
"""
Get ``n_points`` promising points. NOTE: Only SOO with n_points = 1 is
supported.

The expected use case is that these points will be evaluated against an
oracle for hypervolume (if multi-objective) or for the value of the best
parameter (if single-objective).

For multi-objective cases, ``n_points > 1`` is needed. For SOO, ``n_points > 1``
reflects setups where several points can be chosen, evaluated noiselessly or at
high fidelity, and the best of them used.


Args:
experiment: The experiment to get the data from. This should contain
values that would be observed in a realistic setting and not
contain oracle values.
optimization_config: The ``optimization_config`` for the corresponding
``BenchmarkProblem``.
n_points: The number of points to return.
"""
if isinstance(optimization_config, MultiObjectiveOptimizationConfig):
raise NotImplementedError(
"BenchmarkMethod.get_pareto_optimal_parameters is not currently "
"supported for multi-objective problems."
)

if n_points != 1:
raise NotImplementedError(
f"Currently only n_points=1 is supported. Got {n_points=}."
)

# SOO, n=1 case.
# Note: This has the same effect as Scheduler.get_best_parameters
result = BestPointMixin._get_best_trial(
experiment=experiment,
generation_strategy=self.generation_strategy,
optimization_config=optimization_config,
# pyre-fixme: Incompatible parameter type [6]: In call
# `get_pareto_optimal_parameters`, for 4th positional argument,
# expected `Optional[Iterable[int]]` but got `bool`.
**self.best_point_kwargs,
)
if result is None:
# This can happen if no points are predicted to satisfy all outcome
# constraints.
return []

i, params, prediction = none_throws(result)
return [params]


def get_benchmark_scheduler_options(
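A sketch of how the new `best_point_kwargs` field and `get_best_parameters` helper might be exercised; the Sobol-only generation strategy and import paths are assumptions about the Ax version this PR targets, and the commented-out call presumes an `experiment` with observed (non-oracle) data.

```python
from ax.benchmark.benchmark_method import (
    BenchmarkMethod,
    get_benchmark_scheduler_options,
)
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models

method = BenchmarkMethod(
    name="Sobol",
    generation_strategy=GenerationStrategy(
        name="Sobol", steps=[GenerationStep(model=Models.SOBOL, num_trials=-1)]
    ),
    scheduler_options=get_benchmark_scheduler_options(),
    # New field: controls whether best-point selection uses raw observations
    # (False, the default) or model predictions.
    best_point_kwargs={"use_model_predictions": False},
)

# Only SOO with n_points=1 is supported; returns [] if no point is predicted
# to satisfy all outcome constraints.
# best = method.get_best_parameters(
#     experiment=experiment,  # an Ax Experiment with empirical data
#     optimization_config=optimization_config,
#     n_points=1,
# )
```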
24 changes: 24 additions & 0 deletions ax/benchmark/benchmark_problem.py
@@ -74,6 +74,13 @@ class BenchmarkProblem(Base):
search_space: The search space.
runner: The Runner that will be used to generate data for the problem,
including any ground-truth data stored as tracking metrics.
report_inference_value_as_trace: Whether the ``optimization_trace`` on a
``BenchmarkResult`` should use the ``oracle_trace`` (if False,
default) or the ``inference_trace``. See ``BenchmarkResult`` for
more information. Currently, this is only supported for
single-objective problems.
n_best_points: Number of points for a best-point selector to recommend.
Currently, only ``n_best_points=1`` is supported.
"""

name: str
@@ -84,6 +91,17 @@

search_space: SearchSpace = field(repr=False)
runner: BenchmarkRunner = field(repr=False)
report_inference_value_as_trace: bool = False
n_best_points: int = 1

def __post_init__(self) -> None:
if self.n_best_points != 1:
raise NotImplementedError("Only `n_best_points=1` is currently supported.")
if self.report_inference_value_as_trace and self.is_moo:
raise NotImplementedError(
"Inference trace is not supported for MOO. Please set "
"`report_inference_value_as_trace` to False."
)

def get_oracle_experiment_from_params(
self,
@@ -285,6 +303,7 @@ def create_problem_from_botorch(
lower_is_better: bool = True,
observe_noise_sd: bool = False,
search_space: SearchSpace | None = None,
report_inference_value_as_trace: bool = False,
) -> BenchmarkProblem:
"""
Create a `BenchmarkProblem` from a BoTorch `BaseTestProblem`.
@@ -308,6 +327,10 @@
search_space: If provided, the `search_space` of the `BenchmarkProblem`.
Otherwise, a `SearchSpace` with all `RangeParameter`s is created
from the bounds of the test problem.
report_inference_value_as_trace: If True, indicates that the
``optimization_trace`` on a ``BenchmarkResult`` ought to be the
``inference_trace``; otherwise, it will be the ``oracle_trace``.
See ``BenchmarkResult`` for more information.
"""
# pyre-fixme [45]: Invalid class instantiation
test_problem = test_problem_class(**test_problem_kwargs)
@@ -364,4 +387,5 @@ def create_problem_from_botorch(
num_trials=num_trials,
observe_noise_stds=observe_noise_sd,
optimal_value=optimal_value,
report_inference_value_as_trace=report_inference_value_as_trace,
)
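
A sketch of opting a problem into inference-value reporting via the new flag, assuming the parameter names visible in this diff; `Hartmann` is a standard BoTorch test problem and the trial count is arbitrary.

```python
from ax.benchmark.benchmark_problem import create_problem_from_botorch
from botorch.test_functions.synthetic import Hartmann

problem = create_problem_from_botorch(
    test_problem_class=Hartmann,
    test_problem_kwargs={"dim": 6},
    num_trials=30,
    # New flag: report the inference trace as the optimization_trace.
    report_inference_value_as_trace=True,
)

# __post_init__ enforces the current limitations.
assert problem.n_best_points == 1
assert not problem.is_moo  # inference trace is not supported for MOO
```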
43 changes: 34 additions & 9 deletions ax/benchmark/benchmark_result.py
@@ -33,15 +33,38 @@ class BenchmarkResult(Base):
name: Name of the benchmark. Should make it possible to determine the
problem and the method.
seed: Seed used for determinism.
optimization_trace: For single-objective problems, element i of the
optimization trace is the oracle value of the "best" point, computed
after the first i trials have been run. For multi-objective
problems, element i of the optimization trace is the hypervolume of
oracle values at a set of points, also computed after the first i
trials (even if these were ``BatchTrials``). Oracle values are
typically ground-truth (rather than noisy) and evaluated at the
target task and fidelity.

oracle_trace: For single-objective problems, element i of the oracle
trace is the best oracle value of the arms evaluated after the first i
trials. For multi-objective problems, element i of the oracle trace is
the hypervolume of the oracle values of the arms in the first i trials
(which may be ``BatchTrial``s). Oracle values are typically
ground-truth (rather than noisy) and evaluated at the target task and
fidelity.
inference_trace: Inference trace comes from choosing a "best" point
based only on data that would be observable in realistic settings
and then evaluating the oracle value of that point. For
multi-objective problems, we find a Pareto set and evaluate its
hypervolume.

There are several ways of specifying the "best" point: One could
pick the point with the best observed value, or the point with the
best model prediction, and could consider the whole search space,
the set of trials completed so far, etc. How the inference trace is
computed is specified by a best-point selector, which is an
attribute of the `BenchmarkMethod`.

Note: This is not "inference regret", which is a lower-is-better value
that is relative to the best possible value. The inference value
trace is higher-is-better if the problem is a maximization problem
or if the problem is multi-objective (in which case hypervolume is
used). Hence, it is signed the same as ``oracle_trace`` and
``optimization_trace``. ``score_trace`` is higher-is-better and
relative to the optimum.
optimization_trace: Either the ``oracle_trace`` or the
``inference_trace``, depending on whether the ``BenchmarkProblem``
specifies ``report_inference_value_as_trace``. Having ``optimization_trace``
specified separately is useful when we need just one value to
evaluate how well the benchmark went.
score_trace: The scores associated with the problem, typically either
the optimization_trace or inference_trace normalized to a
0-100 scale for comparability between problems.
@@ -56,6 +79,8 @@
name: str
seed: int

oracle_trace: ndarray
inference_trace: ndarray
optimization_trace: ndarray
score_trace: ndarray

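The docstring's distinction between the inference trace and inference regret, as a small self-contained numeric sketch (all values made up):

```python
import numpy as np

optimal_value = 3.0                          # known optimum of a maximization problem
inference_trace = np.array([1.2, 2.1, 2.8])  # oracle value of the recommended point
                                             # after trials 1, 2, 3

# Inference regret is lower-is-better and relative to the optimum, whereas the
# inference trace is signed like the objective itself.
inference_regret = optimal_value - inference_trace  # ≈ [1.8, 0.9, 0.2]
```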
3 changes: 3 additions & 0 deletions ax/benchmark/methods/modular_botorch.py
@@ -48,6 +48,7 @@ def get_sobol_botorch_modular_acquisition(
name: Optional[str] = None,
num_sobol_trials: int = 5,
model_gen_kwargs: Optional[dict[str, Any]] = None,
best_point_kwargs: dict[str, bool] | None = None,
) -> BenchmarkMethod:
"""Get a `BenchmarkMethod` that uses Sobol followed by MBM.

Expand All @@ -64,6 +65,7 @@ def get_sobol_botorch_modular_acquisition(
`BatchTrial`s.
model_gen_kwargs: Passed to the BoTorch `GenerationStep` and ultimately
to the BoTorch `Model`.
best_point_kwargs: Passed to the created `BenchmarkMethod`.

Example:
>>> # A simple example
Expand Down Expand Up @@ -138,4 +140,5 @@ def get_sobol_botorch_modular_acquisition(
generation_strategy=generation_strategy,
scheduler_options=scheduler_options or get_benchmark_scheduler_options(),
distribute_replications=distribute_replications,
best_point_kwargs={} if best_point_kwargs is None else best_point_kwargs,
)
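
A sketch of threading `best_point_kwargs` through the helper; the `model_cls`/`acquisition_cls` arguments lie outside this diff and are assumed here, so adjust to the actual signature in `modular_botorch.py`.

```python
from ax.benchmark.methods.modular_botorch import get_sobol_botorch_modular_acquisition
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.models.gp_regression import SingleTaskGP

method = get_sobol_botorch_modular_acquisition(
    model_cls=SingleTaskGP,                        # assumed existing parameter
    acquisition_cls=qLogNoisyExpectedImprovement,  # assumed existing parameter
    distribute_replications=False,
    num_sobol_trials=5,
    # New: forwarded to the BenchmarkMethod; an empty dict is passed when None.
    best_point_kwargs={"use_model_predictions": True},
)
```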