-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge: New acquisition functions (#203)
Some progress towards refactoring the acqf interface and enabling more advanced acqfs Done - removed `debotorchize` - incldued some new acqfs, see CHANGELOG - enable iteration tests with all acqfs - extended hypothesis tests Not Done Yet: - removing `AdapterModel` (does not make sense while #209 is open) Issues: - ~~some tests with custom surrogates fail, see separate thread~~ resolved since not using botorch factory anymore - some of the analytical new acqfs dont seem to work wiht out GP model. Eg when I implement the `NEI` I get `botorch.exceptions.errors.UnsupportedError: Only SingleTaskGP models with known observation noise are currently supported for fantasy-based NEI & LogNE`. Also the `LogPI` is available in botorch, but it is not imported to the top level acqf package in botorch, I ignored it here. ~~I included a reverted commit pair so it can be checked out quickly to reproduce~~
- Loading branch information
Showing
17 changed files
with
368 additions
and
223 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
"""Adapter for making BoTorch's acquisition functions work with BayBE models.""" | ||
|
||
from typing import Any, Callable, Optional | ||
|
||
import gpytorch.distributions | ||
from botorch.models.gpytorch import Model | ||
from botorch.posteriors import Posterior | ||
from botorch.posteriors.gpytorch import GPyTorchPosterior | ||
from torch import Tensor | ||
|
||
from baybe.surrogates.base import Surrogate | ||
|
||
|
||
class AdapterModel(Model): | ||
"""A BoTorch model that uses a BayBE surrogate model for posterior computation. | ||
Can be used, for example, as an adapter layer for making a BayBE | ||
surrogate model usable in conjunction with BoTorch acquisition functions. | ||
Args: | ||
surrogate: The internal surrogate model | ||
""" | ||
|
||
def __init__(self, surrogate: Surrogate): | ||
super().__init__() | ||
self._surrogate = surrogate | ||
|
||
@property | ||
def num_outputs(self) -> int: # noqa: D102 | ||
# See base class. | ||
# TODO: So far, the usage is limited to single-output models. | ||
return 1 | ||
|
||
def posterior( # noqa: D102 | ||
self, | ||
X: Tensor, | ||
output_indices: Optional[list[int]] = None, | ||
observation_noise: bool = False, | ||
posterior_transform: Optional[Callable[[Posterior], Posterior]] = None, | ||
**kwargs: Any, | ||
) -> Posterior: | ||
# See base class. | ||
mean, var = self._surrogate.posterior(X) | ||
mvn = gpytorch.distributions.MultivariateNormal(mean, var) | ||
return GPyTorchPosterior(mvn) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.