Support different noise levels for different outputs in test functions #2136

Closed · wants to merge 14 commits
41 changes: 31 additions & 10 deletions botorch/test_functions/base.py
@@ -11,9 +11,10 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
from typing import List, Tuple, Union

import torch

from botorch.exceptions.errors import InputDataError
from torch import Tensor
from torch.nn import Module
@@ -26,11 +27,17 @@ class BaseTestProblem(Module, ABC):
_bounds: List[Tuple[float, float]]
_check_grad_at_opt: bool = True

def __init__(self, noise_std: Optional[float] = None, negate: bool = False) -> None:
def __init__(
self,
noise_std: Union[None, float, List[float]] = None,
negate: bool = False,
) -> None:
r"""Base constructor for test functions.

Args:
noise_std: Standard deviation of the observation noise.
noise_std: Standard deviation of the observation noise. If a list is
provided, specifies separate noise standard deviations for each
objective in a multiobjective problem.
negate: If True, negate the function.
"""
super().__init__()
@@ -60,7 +67,8 @@ def forward(self, X: Tensor, noise: bool = True) -> Tensor:
X = X if batch else X.unsqueeze(0)
f = self.evaluate_true(X=X)
if noise and self.noise_std is not None:
f += self.noise_std * torch.randn_like(f)
_noise = torch.tensor(self.noise_std, device=X.device, dtype=X.dtype)
f += _noise * torch.randn_like(f)
if self.negate:
f = -f
return f if batch else f.squeeze(0)
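
As a quick illustration of the broadcasting here (a minimal sketch, not part of the diff): when `noise_std` is a list with one entry per output, the resulting 1-d tensor broadcasts against the trailing output dimension of `f`, so each output column is perturbed with its own noise scale.

import torch

f = torch.zeros(10_000, 2)                      # stand-in for objective values with two outputs
noise_std = [0.1, 10.0]                         # separate noise level per output
_noise = torch.tensor(noise_std, dtype=f.dtype)
f_noisy = f + _noise * torch.randn_like(f)      # broadcasts over the last dimension
print(f_noisy.std(dim=0))                       # approximately tensor([0.1, 10.0])
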
@@ -82,6 +90,7 @@ class ConstrainedBaseTestProblem(BaseTestProblem, ABC):

num_constraints: int
_check_grad_at_opt: bool = False
constraint_noise_std: Union[None, float, List[float]] = None

def evaluate_slack(self, X: Tensor, noise: bool = True) -> Tensor:
r"""Evaluate the constraint slack on a set of points.
@@ -101,10 +110,11 @@ def evaluate_slack(self, X: Tensor, noise: bool = True) -> Tensor:
corresponds to the constraint being feasible).
"""
cons = self.evaluate_slack_true(X=X)
if noise and self.noise_std is not None:
# TODO: Allow different noise levels for objective and constraints (and
# different noise levels between different constraints)
cons += self.noise_std * torch.randn_like(cons)
if noise and self.constraint_noise_std is not None:
_constraint_noise = torch.tensor(
self.constraint_noise_std, device=X.device, dtype=X.dtype
)
cons += _constraint_noise * torch.randn_like(cons)
return cons

def is_feasible(self, X: Tensor, noise: bool = True) -> Tensor:
@@ -147,13 +157,24 @@ class MultiObjectiveTestProblem(BaseTestProblem):
_ref_point: List[float]
_max_hv: float

def __init__(self, noise_std: Optional[float] = None, negate: bool = False) -> None:
def __init__(
self,
noise_std: Union[None, float, List[float]] = None,
negate: bool = False,
) -> None:
r"""Base constructor for multi-objective test functions.

Args:
noise_std: Standard deviation of the observation noise.
noise_std: Standard deviation of the observation noise. If a list is
provided, specifies separate noise standard deviations for each
objective.
negate: If True, negate the objectives.
"""
if isinstance(noise_std, list) and len(noise_std) != len(self._ref_point):
raise InputDataError(
f"If specified as a list, length of noise_std ({len(noise_std)}) "
f"must match the number of objectives ({len(self._ref_point)})"
)
super().__init__(noise_std=noise_std, negate=negate)
ref_point = torch.tensor(self._ref_point, dtype=torch.float)
if negate:
24 changes: 16 additions & 8 deletions botorch/test_functions/multi_objective.py
@@ -76,7 +76,7 @@
import math
from abc import ABC, abstractmethod
from math import pi
from typing import Optional
from typing import List, Union

import torch
from botorch.exceptions.errors import UnsupportedError
@@ -116,7 +116,11 @@ class BraninCurrin(MultiObjectiveTestProblem):
_ref_point = [18.0, 6.0]
_max_hv = 59.36011874867746 # this is approximated using NSGA-II

def __init__(self, noise_std: Optional[float] = None, negate: bool = False) -> None:
def __init__(
self,
noise_std: Union[None, float, List[float]] = None,
negate: bool = False,
) -> None:
r"""
Args:
noise_std: Standard deviation of the observation noise.
@@ -174,7 +178,7 @@ class DH(MultiObjectiveTestProblem, ABC):
def __init__(
self,
dim: int,
noise_std: Optional[float] = None,
noise_std: Union[None, float, List[float]] = None,
negate: bool = False,
) -> None:
r"""
@@ -334,7 +338,7 @@ def __init__(
self,
dim: int,
num_objectives: int = 2,
noise_std: Optional[float] = None,
noise_std: Union[None, float, List[float]] = None,
negate: bool = False,
) -> None:
r"""
@@ -600,7 +604,7 @@ class GMM(MultiObjectiveTestProblem):

def __init__(
self,
noise_std: Optional[float] = None,
noise_std: Union[None, float, List[float]] = None,
negate: bool = False,
num_objectives: int = 2,
) -> None:
@@ -926,7 +930,7 @@ def __init__(
self,
dim: int,
num_objectives: int = 2,
noise_std: Optional[float] = None,
noise_std: Union[None, float, List[float]] = None,
negate: bool = False,
) -> None:
r"""
@@ -1234,7 +1238,7 @@ class ConstrainedBraninCurrin(BraninCurrin, ConstrainedBaseTestProblem):
_ref_point = [80.0, 12.0]
_max_hv = 608.4004237022673 # from NSGA-II with 90k evaluations

def __init__(self, noise_std: Optional[float] = None, negate: bool = False) -> None:
def __init__(
self,
noise_std: Union[None, float, List[float]] = None,
negate: bool = False,
) -> None:
r"""
Args:
noise_std: Standard deviation of the observation noise.
@@ -1337,7 +1345,7 @@ class MW7(MultiObjectiveTestProblem, ConstrainedBaseTestProblem):
def __init__(
self,
dim: int,
noise_std: Optional[float] = None,
noise_std: Union[None, float, List[float]] = None,
negate: bool = False,
) -> None:
r"""
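
A short usage sketch (not from the diff) of the per-objective `noise_std` on `BraninCurrin`, which has two objectives; per the validation added in `base.py`, a list whose length does not match the number of objectives is expected to raise `InputDataError`.

import torch
from botorch.exceptions.errors import InputDataError
from botorch.test_functions.multi_objective import BraninCurrin

problem = BraninCurrin(noise_std=[5.0, 0.5])    # separate noise level for each objective
X = torch.rand(8, problem.dim)
Y = problem(X)                                  # noisy evaluation, shape 8 x 2

try:
    BraninCurrin(noise_std=[1.0, 2.0, 3.0])     # three entries for a two-objective problem
except InputDataError as error:
    print(error)
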
119 changes: 109 additions & 10 deletions botorch/test_functions/synthetic.py
@@ -47,9 +47,10 @@
from __future__ import annotations

import math
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import torch
from botorch.exceptions.errors import InputDataError
from botorch.test_functions.base import BaseTestProblem, ConstrainedBaseTestProblem
from botorch.test_functions.utils import round_nearest
from torch import Tensor
@@ -64,13 +65,15 @@ class SyntheticTestFunction(BaseTestProblem):

def __init__(
self,
noise_std: Optional[float] = None,
noise_std: Union[None, float, List[float]] = None,
negate: bool = False,
bounds: Optional[List[Tuple[float, float]]] = None,
) -> None:
r"""
Args:
noise_std: Standard deviation of the observation noise.
noise_std: Standard deviation of the observation noise. If a list is
provided, specifies separate noise standard deviations for each
objective in a multiobjective problem.
negate: If True, negate the function.
bounds: Custom bounds for the function specified as (lower, upper) pairs.
"""
@@ -802,7 +805,57 @@ def evaluate_true(self, X: Tensor) -> Tensor:
# ------------ Constrained synthetic test functions ----------- #


class ConstrainedGramacy(ConstrainedBaseTestProblem, SyntheticTestFunction):
class ConstrainedSyntheticTestFunction(
ConstrainedBaseTestProblem, SyntheticTestFunction
):
r"""Base class for constrained synthetic test functions."""

def __init__(
self,
noise_std: Union[None, float, List[float]] = None,
constraint_noise_std: Union[None, float, List[float]] = None,
negate: bool = False,
bounds: Optional[List[Tuple[float, float]]] = None,
) -> None:
r"""
Args:
noise_std: Standard deviation of the observation noise. If a list is
provided, specifies separate noise standard deviations for each
objective in a multiobjective problem.
constraint_noise_std: Standard deviation of the constraint noise.
If a list is provided, specifies separate noise standard
deviations for each constraint.
negate: If True, negate the function.
bounds: Custom bounds for the function specified as (lower, upper) pairs.
"""
self.setup_constraint_noise(constraint_noise_std)
SyntheticTestFunction.__init__(
self, noise_std=noise_std, negate=negate, bounds=bounds
)

def setup_constraint_noise(
self,
constraint_noise_std: Union[None, float, List[float]],
) -> None:
r"""Validate and store the constraint noise level(s).

If constraint_noise_std is given as a list, its length must equal the
number of constraints.

Args:
constraint_noise_std: Standard deviation of the constraint noise.
If a list is provided, specifies separate noise standard
deviations for each constraint.
"""
if (
isinstance(constraint_noise_std, list)
and len(constraint_noise_std) != self.num_constraints
):
raise InputDataError(
f"If specified as a list, length of constraint_noise_std "
f"({len(constraint_noise_std)}) must match the "
f"number of constraints ({self.num_constraints})"
)
self.constraint_noise_std = constraint_noise_std


class ConstrainedGramacy(ConstrainedSyntheticTestFunction):
r"""Constrained Gramacy test function.

This problem comes from [Gramacy2016]_. The problem is defined
@@ -835,31 +888,77 @@ def evaluate_slack_true(self, X: Tensor) -> Tensor:
return torch.cat([-c1, -c2], dim=-1)
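
A usage sketch (not part of the diff) for `ConstrainedGramacy` under the new base class, assuming its `num_constraints` is 2 (consistent with the two slack values returned above); a mismatched `constraint_noise_std` list is expected to raise `InputDataError`.

import torch
from botorch.exceptions.errors import InputDataError
from botorch.test_functions.synthetic import ConstrainedGramacy

problem = ConstrainedGramacy(noise_std=0.1, constraint_noise_std=[0.01, 0.05])
X = torch.rand(4, problem.dim)
slack = problem.evaluate_slack(X)               # noisy slacks, one column per constraint

try:
    ConstrainedGramacy(constraint_noise_std=[0.01])  # one entry for two constraints
except InputDataError as error:
    print(error)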


class ConstrainedHartmann(Hartmann, ConstrainedBaseTestProblem):
class ConstrainedHartmann(Hartmann, ConstrainedSyntheticTestFunction):
r"""Constrained Hartmann test function.

This is a constrained version of the standard Hartmann test function that
uses `||x||_2 <= 1` as the constraint. This problem comes from [Letham2019]_.
"""
num_constraints = 1

def __init__(
self,
dim: int = 6,
noise_std: Union[None, float] = None,
constraint_noise_std: Union[None, float, List[float]] = None,
negate: bool = False,
bounds: Optional[List[Tuple[float, float]]] = None,
) -> None:
r"""
Args:
dim: The (input) dimension.
noise_std: Standard deviation of the observation noise.
constraint_noise_std: Standard deviation of the constraint noise.
If a list is provided, specifies separate noise standard
deviations for each constraint.
negate: If True, negate the function.
bounds: Custom bounds for the function specified as (lower, upper) pairs.
"""
self.setup_constraint_noise(constraint_noise_std)
Hartmann.__init__(
self, dim=dim, noise_std=noise_std, negate=negate, bounds=bounds
)

def evaluate_slack_true(self, X: Tensor) -> Tensor:
return -X.norm(dim=-1, keepdim=True) + 1
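
A short sketch (not part of the diff) showing separate objective and constraint noise on `ConstrainedHartmann` (`num_constraints = 1`); evaluating with `noise=False` should reproduce the noiseless values.

import torch
from botorch.test_functions.synthetic import ConstrainedHartmann

problem = ConstrainedHartmann(dim=6, noise_std=0.5, constraint_noise_std=0.05)
X = torch.rand(4, 6)
Y_noisy = problem(X)                            # objective observed with noise_std = 0.5
slack = problem.evaluate_slack(X)               # constraint slack observed with noise_std = 0.05
Y_true = problem(X, noise=False)                # noiseless objective (negate defaults to False)
assert torch.equal(Y_true, problem.evaluate_true(X))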


class ConstrainedHartmannSmooth(Hartmann, ConstrainedBaseTestProblem):
class ConstrainedHartmannSmooth(Hartmann, ConstrainedSyntheticTestFunction):
r"""Smooth constrained Hartmann test function.

This is a constrained version of the standard Hartmann test function that
uses `||x||_2^2 <= 1` as the constraint to obtain smoother constraint slack.
"""
num_constraints = 1

def __init__(
self,
dim: int = 6,
noise_std: Union[None, float] = None,
constraint_noise_std: Union[None, float, List[float]] = None,
negate: bool = False,
bounds: Optional[List[Tuple[float, float]]] = None,
) -> None:
r"""
Args:
dim: The (input) dimension.
noise_std: Standard deviation of the observation noise.
constraint_noise_std: Standard deviation of the constraint noise.
If a list is provided, specifies separate noise standard
deviations for each constraint.
negate: If True, negate the function.
bounds: Custom bounds for the function specified as (lower, upper) pairs.
"""
self.setup_constraint_noise(constraint_noise_std)
Hartmann.__init__(
self, dim=dim, noise_std=noise_std, negate=negate, bounds=bounds
)

def evaluate_slack_true(self, X: Tensor) -> Tensor:
return -X.pow(2).sum(dim=-1, keepdim=True) + 1


class PressureVessel(SyntheticTestFunction, ConstrainedBaseTestProblem):
class PressureVessel(ConstrainedSyntheticTestFunction):
r"""Pressure vessel design problem with constraints.

The four-dimensional pressure vessel design problem with four black-box
@@ -894,7 +993,7 @@ def evaluate_slack_true(self, X: Tensor) -> Tensor:
)


class WeldedBeamSO(SyntheticTestFunction, ConstrainedBaseTestProblem):
class WeldedBeamSO(ConstrainedSyntheticTestFunction):
r"""Welded beam design problem with constraints (single-outcome).

The four-dimensional welded beam design problem with six
@@ -950,7 +1049,7 @@ def evaluate_slack_true(self, X: Tensor) -> Tensor:
return -torch.stack([g1, g2, g3, g4, g5, g6], dim=-1)


class TensionCompressionString(SyntheticTestFunction, ConstrainedBaseTestProblem):
class TensionCompressionString(ConstrainedSyntheticTestFunction):
r"""Tension compression string optimization problem with constraints.

The three-dimensional tension compression string optimization problem with
@@ -981,7 +1080,7 @@ def evaluate_slack_true(self, X: Tensor) -> Tensor:
return -constraints.clamp_max(100)


class SpeedReducer(SyntheticTestFunction, ConstrainedBaseTestProblem):
class SpeedReducer(ConstrainedSyntheticTestFunction):
r"""Speed Reducer design problem with constraints.

The seven-dimensional speed reducer design problem with eleven black-box
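
Because `PressureVessel`, `WeldedBeamSO`, `TensionCompressionString`, and `SpeedReducer` now subclass `ConstrainedSyntheticTestFunction` without defining their own `__init__` in this diff, they pick up `constraint_noise_std` from the shared constructor. A minimal sketch (a scalar noise level passes validation regardless of the constraint count):

import torch
from botorch.test_functions.synthetic import TensionCompressionString

problem = TensionCompressionString(noise_std=0.1, constraint_noise_std=0.01)
bounds = problem.bounds                         # 2 x d tensor of (lower, upper) bounds
X = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(5, problem.dim)
feasible = problem.is_feasible(X)               # uses noisy slacks by default
slack = problem.evaluate_slack(X, noise=False)  # noiseless constraint slacks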