Skip to content

Commit

Permalink
Merge: New acquisition functions (#203)
Browse files Browse the repository at this point in the history
Some progress towards refactoring the acqf interface and enabling more
advanced acqfs

Done
- removed `debotorchize`
- incldued some new acqfs, see CHANGELOG
- enable iteration tests with all acqfs
- extended hypothesis tests

Not Done Yet:
- removing `AdapterModel` (does not make sense while #209 is open)

Issues:
- ~~some tests with custom surrogates fail, see separate thread~~
resolved since not using botorch factory anymore
- some of the analytical new acqfs dont seem to work wiht out GP model.
Eg when I implement the `NEI` I get
`botorch.exceptions.errors.UnsupportedError: Only SingleTaskGP models
with known observation noise are currently supported for fantasy-based
NEI & LogNE`. Also the `LogPI` is available in botorch, but it is not
imported to the top level acqf package in botorch, I ignored it here.
~~I included a reverted commit pair so it can be checked out quickly to
reproduce~~
  • Loading branch information
Scienfitz authored Apr 29, 2024
2 parents 4613795 + 580c3ff commit 538ac80
Show file tree
Hide file tree
Showing 17 changed files with 368 additions and 223 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `mypy` for search space and objectives
- Class hierarchy for objectives
- Deserialization is now also possible from optional class name abbreviations
- Hypothesis strategies for acquisition functions
- `Kernel` base class allowing to specify kernels
- `MaternKernel` class can be chosen for GP surrogates
- `hypothesis` strategies and roundtrip test for kernels, constraints and objectives
- `hypothesis` strategies and roundtrip test for kernels, constraints, objectives and acquisition
functions
- New acquisition functions: `qSR`, `qNEI`, `LogEI`, `qLogEI`, `qLogNEI`

### Changed
- Reorganized acquisition.py into `acquisition` subpackage
Expand Down
52 changes: 42 additions & 10 deletions baybe/acquisition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,67 @@

from baybe.acquisition.acqfs import (
ExpectedImprovement,
LogExpectedImprovement,
PosteriorMean,
ProbabilityOfImprovement,
UpperConfidenceBound,
qExpectedImprovement,
qLogExpectedImprovement,
qLogNoisyExpectedImprovement,
qNoisyExpectedImprovement,
qProbabilityOfImprovement,
qSimpleRegret,
qUpperConfidenceBound,
)

PM = PosteriorMean
qSR = qSimpleRegret
EI = ExpectedImprovement
PI = ProbabilityOfImprovement
UCB = UpperConfidenceBound
qEI = qExpectedImprovement
LogEI = LogExpectedImprovement
qLogEI = qLogExpectedImprovement
qNEI = qNoisyExpectedImprovement
qLogNEI = qLogNoisyExpectedImprovement
PI = ProbabilityOfImprovement
qPI = qProbabilityOfImprovement
UCB = UpperConfidenceBound
qUCB = qUpperConfidenceBound

__all__ = [
# ---------------------------
# Acquisition functions
######################### Acquisition functions
# Posterior Mean
"PosteriorMean",
# Simple Regret
"qSimpleRegret",
# Expected Improvement
"ExpectedImprovement",
"ProbabilityOfImprovement",
"UpperConfidenceBound",
"qExpectedImprovement",
"LogExpectedImprovement",
"qLogExpectedImprovement",
"qNoisyExpectedImprovement",
"qLogNoisyExpectedImprovement",
# Probability of Improvement
"ProbabilityOfImprovement",
"qProbabilityOfImprovement",
# Upper Confidence Bound
"UpperConfidenceBound",
"qUpperConfidenceBound",
# ---------------------------
# Abbreviations
######################### Abbreviations
# Posterior Mean
"PM",
# Simple Regret
"qSR",
# Expected Improvement
"EI",
"PI",
"UCB",
"qEI",
"LogEI",
"qLogEI",
"qNEI",
"qLogNEI",
# Probability of Improvement
"PI",
"qPI",
# Upper Confidence Bound
"UCB",
"qUCB",
]
45 changes: 45 additions & 0 deletions baybe/acquisition/_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Adapter for making BoTorch's acquisition functions work with BayBE models."""

from typing import Any, Callable, Optional

import gpytorch.distributions
from botorch.models.gpytorch import Model
from botorch.posteriors import Posterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from torch import Tensor

from baybe.surrogates.base import Surrogate


class AdapterModel(Model):
"""A BoTorch model that uses a BayBE surrogate model for posterior computation.
Can be used, for example, as an adapter layer for making a BayBE
surrogate model usable in conjunction with BoTorch acquisition functions.
Args:
surrogate: The internal surrogate model
"""

def __init__(self, surrogate: Surrogate):
super().__init__()
self._surrogate = surrogate

@property
def num_outputs(self) -> int: # noqa: D102
# See base class.
# TODO: So far, the usage is limited to single-output models.
return 1

def posterior( # noqa: D102
self,
X: Tensor,
output_indices: Optional[list[int]] = None,
observation_noise: bool = False,
posterior_transform: Optional[Callable[[Posterior], Posterior]] = None,
**kwargs: Any,
) -> Posterior:
# See base class.
mean, var = self._surrogate.posterior(X)
mvn = gpytorch.distributions.MultivariateNormal(mean, var)
return GPyTorchPosterior(mvn)
77 changes: 64 additions & 13 deletions baybe/acquisition/acqfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,31 @@
from typing import ClassVar

from attrs import define, field
from attrs.validators import ge
from attrs.validators import ge, instance_of

from baybe.acquisition.base import AcquisitionFunction


########################################################################################
### Posterior Mean
@define(frozen=True)
class PosteriorMean(AcquisitionFunction):
"""Posterior mean."""

_abbreviation: ClassVar[str] = "PM"


########################################################################################
### Simple Regret
@define(frozen=True)
class qExpectedImprovement(AcquisitionFunction):
"""Monte Carlo based expected improvement."""
class qSimpleRegret(AcquisitionFunction):
"""Monte Carlo based simple regret."""

_abbreviation: ClassVar[str] = "qEI"
_abbreviation: ClassVar[str] = "qSR"


########################################################################################
### Expected Improvement
@define(frozen=True)
class ExpectedImprovement(AcquisitionFunction):
"""Analytical expected improvement."""
Expand All @@ -30,12 +36,48 @@ class ExpectedImprovement(AcquisitionFunction):


@define(frozen=True)
class qProbabilityOfImprovement(AcquisitionFunction):
"""Monte Carlo based probability of improvement."""
class qExpectedImprovement(AcquisitionFunction):
"""Monte Carlo based expected improvement."""

_abbreviation: ClassVar[str] = "qEI"


@define(frozen=True)
class LogExpectedImprovement(AcquisitionFunction):
"""Logarithmic analytical expected improvement."""

_abbreviation: ClassVar[str] = "LogEI"


@define(frozen=True)
class qLogExpectedImprovement(AcquisitionFunction):
"""Logarithmic Monte Carlo based expected improvement."""

_abbreviation: ClassVar[str] = "qLogEI"


@define(frozen=True)
class qNoisyExpectedImprovement(AcquisitionFunction):
"""Monte Carlo based noisy expected improvement."""

_abbreviation: ClassVar[str] = "qNEI"

prune_baseline: bool = field(default=True, validator=instance_of(bool))
"""Auto-prune candidates that are unlikely to be the best."""

_abbreviation: ClassVar[str] = "qPI"

@define(frozen=True)
class qLogNoisyExpectedImprovement(AcquisitionFunction):
"""Logarithmic Monte Carlo based noisy expected improvement."""

_abbreviation: ClassVar[str] = "qLogNEI"

prune_baseline: bool = field(default=True, validator=instance_of(bool))
"""Auto-prune candidates that are unlikely to be the best."""


########################################################################################
### Probability of Improvement
@define(frozen=True)
class ProbabilityOfImprovement(AcquisitionFunction):
"""Analytical probability of improvement."""
Expand All @@ -44,10 +86,19 @@ class ProbabilityOfImprovement(AcquisitionFunction):


@define(frozen=True)
class qUpperConfidenceBound(AcquisitionFunction):
"""Monte Carlo based upper confidence bound."""
class qProbabilityOfImprovement(AcquisitionFunction):
"""Monte Carlo based probability of improvement."""

_abbreviation: ClassVar[str] = "qPI"

_abbreviation: ClassVar[str] = "qUCB"

########################################################################################
### Upper Confidence Bound
@define(frozen=True)
class UpperConfidenceBound(AcquisitionFunction):
"""Analytical upper confidence bound."""

_abbreviation: ClassVar[str] = "UCB"

beta: float = field(converter=float, validator=ge(0.0), default=0.2)
"""Trade-off parameter for mean and variance.
Expand All @@ -59,10 +110,10 @@ class qUpperConfidenceBound(AcquisitionFunction):


@define(frozen=True)
class UpperConfidenceBound(AcquisitionFunction):
"""Analytical upper confidence bound."""
class qUpperConfidenceBound(AcquisitionFunction):
"""Monte Carlo based upper confidence bound."""

_abbreviation: ClassVar[str] = "UCB"
_abbreviation: ClassVar[str] = "qUCB"

beta: float = field(converter=float, validator=ge(0.0), default=0.2)
"""Trade-off parameter for mean and variance.
Expand Down
94 changes: 0 additions & 94 deletions baybe/acquisition/adapter.py

This file was deleted.

Loading

0 comments on commit 538ac80

Please sign in to comment.