Skip to content

Commit

Permalink
Fix batch recommendation (emdgroup#348)
Browse files Browse the repository at this point in the history
When the non-GP surrogates where added, they were not designed with
batch recommendations in mind. Now that batch recommendations can be
generally requested, surrogate models providing only marginal posterior
information yield unusable results for batch sizes larger than one.

This PR fixes the issue by:
* Making the `RandomForestSurrogate` produce a proper
`EnsemblePosterior` that is capable of expressing covariance structure,
which is the basic requirement for batch prediction
* Disallowing batch sizes larger than one for surrogates that only
provide marginal posteriors
  • Loading branch information
AdrianSosic authored Sep 3, 2024
2 parents 435c61e + 4dac89d commit 1e387b0
Show file tree
Hide file tree
Showing 15 changed files with 212 additions and 193 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- `py.typed` file to enable the use of type checkers on the user side
- `GaussianSurrogate` base class for surrogate models with Gaussian posteriors
- `IndependentGaussianSurrogate` base class for surrogate models providing independent
Gaussian posteriors for all candidates (cannot be used for batch prediction)
- `comp_rep_columns` property for `Parameter`, `SearchSpace`, `SubspaceDiscrete`
and `SubspaceContinuous` classes
- New mechanisms for surrogate input/output scaling configurable per class
Expand All @@ -32,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `CategoricalParameter` and `TaskParameter` no longer incorrectly coerce a single
string input to categories/tasks
- `farthest_point_sampling` no longer depends on the provided point order
- Batch predictions for `RandomForestSurrogate`
- Surrogates providing only marginal posterior information can no longer be used for
batch recommendation

### Removed
- `register_custom_architecture` decorator
Expand Down
4 changes: 4 additions & 0 deletions baybe/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,7 @@ class ModelNotTrainedError(Exception):

class UnmatchedAttributeError(Exception):
"""An attribute cannot be matched against a certain callable signature."""


class InvalidSurrogateModelError(Exception):
"""An invalid surrogate model was chosen."""
14 changes: 12 additions & 2 deletions baybe/recommenders/pure/bayesian/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from baybe.acquisition.acqfs import qLogExpectedImprovement
from baybe.acquisition.base import AcquisitionFunction
from baybe.acquisition.utils import convert_acqf
from baybe.exceptions import DeprecationError
from baybe.exceptions import DeprecationError, InvalidSurrogateModelError
from baybe.objectives.base import Objective
from baybe.recommenders.pure.base import PureRecommender
from baybe.searchspace import SearchSpace
from baybe.surrogates import CustomONNXSurrogate, GaussianProcessSurrogate
from baybe.surrogates.base import SurrogateProtocol
from baybe.surrogates.base import IndependentGaussianSurrogate, SurrogateProtocol


@define
Expand Down Expand Up @@ -76,6 +76,16 @@ def recommend( # noqa: D102
f"empty training data."
)

if (
isinstance(self.surrogate_model, IndependentGaussianSurrogate)
and batch_size > 1
):
raise InvalidSurrogateModelError(
f"The specified surrogate model of type "
f"'{self.surrogate_model.__class__.__name__}' "
f"cannot be used for batch recommendation."
)

if isinstance(self.surrogate_model, CustomONNXSurrogate):
CustomONNXSurrogate.validate_compatibility(searchspace)

Expand Down
19 changes: 4 additions & 15 deletions baybe/surrogates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,6 @@ def to_botorch(self) -> Model:
class Surrogate(ABC, SurrogateProtocol, SerialMixin):
"""Abstract base class for all surrogate models."""

# Class variables
joint_posterior: ClassVar[bool]
"""Class variable encoding whether or not a joint posterior is calculated."""

supports_transfer_learning: ClassVar[bool]
"""Class variable encoding whether or not the surrogate supports transfer
learning."""
Expand Down Expand Up @@ -318,8 +314,8 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:


@define
class GaussianSurrogate(Surrogate, ABC):
"""A surrogate model providing Gaussian posterior estimates."""
class IndependentGaussianSurrogate(Surrogate, ABC):
"""A surrogate base class providing independent Gaussian posteriors."""

def _posterior(self, candidates_comp_scaled: Tensor, /) -> GPyTorchPosterior:
# See base class.
Expand All @@ -330,21 +326,14 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> GPyTorchPosterior:

# Construct the Gaussian posterior from the estimated first and second moment
mean, var = self._estimate_moments(candidates_comp_scaled)
if not self.joint_posterior:
var = torch.diag_embed(var)
mvn = MultivariateNormal(mean, var)
mvn = MultivariateNormal(mean, torch.diag_embed(var))
return GPyTorchPosterior(mvn)

@abstractmethod
def _estimate_moments(
self, candidates_comp_scaled: Tensor, /
) -> tuple[Tensor, Tensor]:
"""Estimate first and second moments of the Gaussian posterior.
The second moment may either be a 1-D tensor of marginal variances for the
candidates or a 2-D tensor representing a full covariance matrix over all
candidates, depending on the ``joint_posterior`` flag of the model.
"""
"""Estimate first and second moments of the Gaussian posterior."""


def _make_hook_decode_onnx_str(
Expand Down
19 changes: 8 additions & 11 deletions baybe/surrogates/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
TaskParameter,
)
from baybe.searchspace import SearchSpace
from baybe.surrogates.base import GaussianSurrogate
from baybe.surrogates.utils import batchify
from baybe.surrogates.base import IndependentGaussianSurrogate
from baybe.surrogates.utils import batchify_mean_var_prediction
from baybe.utils.numerical import DTypeFloatONNX

if TYPE_CHECKING:
Expand All @@ -43,20 +43,15 @@ def register_custom_architecture(*args, **kwargs) -> NoReturn:


@define(kw_only=True)
class CustomONNXSurrogate(GaussianSurrogate):
class CustomONNXSurrogate(IndependentGaussianSurrogate):
"""A wrapper class for custom pretrained surrogate models.
Note that these surrogates cannot be retrained.
"""

# Class variables
joint_posterior: ClassVar[bool] = False
# See base class.

supports_transfer_learning: ClassVar[bool] = False
# See base class.

# Object variables
onnx_input_name: str = field(validator=validators.instance_of(str))
"""The input name used for constructing the ONNX str."""

Expand All @@ -78,14 +73,16 @@ def default_model(self) -> ort.InferenceSession:
except Exception as exc:
raise ValueError("Invalid ONNX string") from exc

@batchify
def _estimate_moments(self, candidates_comp: Tensor, /) -> tuple[Tensor, Tensor]:
@batchify_mean_var_prediction
def _estimate_moments(
self, candidates_comp_scaled: Tensor, /
) -> tuple[Tensor, Tensor]:
import torch

from baybe.utils.torch import DTypeFloatTorch

model_inputs = {
self.onnx_input_name: candidates_comp.numpy().astype(DTypeFloatONNX)
self.onnx_input_name: candidates_comp_scaled.numpy().astype(DTypeFloatONNX)
}
results = self._model.run(None, model_inputs)

Expand Down
5 changes: 0 additions & 5 deletions baybe/surrogates/gaussian_process/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,9 @@ class GaussianProcessSurrogate(Surrogate):
# to `optimize_acqf_*`, which is configured to be called on the original scale.
# Moving the scaling operation into the botorch GP object avoids this conflict.

# Class variables
joint_posterior: ClassVar[bool] = True
# See base class.

supports_transfer_learning: ClassVar[bool] = True
# See base class.

# Object variables
kernel_factory: KernelFactory = field(
alias="kernel_or_factory",
factory=DefaultKernelFactory,
Expand Down
22 changes: 5 additions & 17 deletions baybe/surrogates/linear.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
"""Linear surrogates.
Currently, the documentation for this surrogate is not available. This is due to a bug
in our documentation tool, see https://github.com/sphinx-doc/sphinx/issues/11750.
Since we plan to refactor the surrogates, this part of the documentation will be
available in the future. Thus, please have a look in the source code directly.
"""
"""Linear surrogates."""

from __future__ import annotations

Expand All @@ -14,8 +7,8 @@
from attr import define, field
from sklearn.linear_model import ARDRegression

from baybe.surrogates.base import GaussianSurrogate
from baybe.surrogates.utils import batchify, catch_constant_targets
from baybe.surrogates.base import IndependentGaussianSurrogate
from baybe.surrogates.utils import batchify_mean_var_prediction, catch_constant_targets
from baybe.surrogates.validation import get_model_params_validator

if TYPE_CHECKING:
Expand All @@ -24,17 +17,12 @@

@catch_constant_targets
@define
class BayesianLinearSurrogate(GaussianSurrogate):
class BayesianLinearSurrogate(IndependentGaussianSurrogate):
"""A Bayesian linear regression surrogate model."""

# Class variables
joint_posterior: ClassVar[bool] = False
# See base class.

supports_transfer_learning: ClassVar[bool] = False
# See base class.

# Object variables
model_params: dict[str, Any] = field(
factory=dict,
converter=dict,
Expand All @@ -45,7 +33,7 @@ class BayesianLinearSurrogate(GaussianSurrogate):
_model: ARDRegression | None = field(init=False, default=None, eq=False)
"""The actual model."""

@batchify
@batchify_mean_var_prediction
def _estimate_moments(
self, candidates_comp_scaled: Tensor, /
) -> tuple[Tensor, Tensor]:
Expand Down
13 changes: 4 additions & 9 deletions baybe/surrogates/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,28 @@

from attr import define, field

from baybe.surrogates.base import GaussianSurrogate
from baybe.surrogates.utils import batchify
from baybe.surrogates.base import IndependentGaussianSurrogate
from baybe.surrogates.utils import batchify_mean_var_prediction

if TYPE_CHECKING:
from torch import Tensor


@define
class MeanPredictionSurrogate(GaussianSurrogate):
class MeanPredictionSurrogate(IndependentGaussianSurrogate):
"""A trivial surrogate model.
It provides the average value of the training targets
as posterior mean and a (data-independent) constant posterior variance.
"""

# Class variables
joint_posterior: ClassVar[bool] = False
# See base class.

supports_transfer_learning: ClassVar[bool] = False
# See base class.

# Object variables
_model: float | None = field(init=False, default=None, eq=False)
"""The estimated posterior mean value of the training targets."""

@batchify
@batchify_mean_var_prediction
def _estimate_moments(
self, candidates_comp_scaled: Tensor, /
) -> tuple[Tensor, Tensor]:
Expand Down
22 changes: 5 additions & 17 deletions baybe/surrogates/ngboost.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
"""NGBoost surrogates.
Currently, the documentation for this surrogate is not available. This is due to a bug
in our documentation tool, see https://github.com/sphinx-doc/sphinx/issues/11750.
Since we plan to refactor the surrogates, this part of the documentation will be
available in the future. Thus, please have a look in the source code directly.
"""
"""NGBoost surrogates."""

from __future__ import annotations

Expand All @@ -15,8 +8,8 @@
from ngboost import NGBRegressor

from baybe.parameters.base import Parameter
from baybe.surrogates.base import GaussianSurrogate
from baybe.surrogates.utils import batchify, catch_constant_targets
from baybe.surrogates.base import IndependentGaussianSurrogate
from baybe.surrogates.utils import batchify_mean_var_prediction, catch_constant_targets
from baybe.surrogates.validation import get_model_params_validator

if TYPE_CHECKING:
Expand All @@ -27,20 +20,15 @@

@catch_constant_targets
@define
class NGBoostSurrogate(GaussianSurrogate):
class NGBoostSurrogate(IndependentGaussianSurrogate):
"""A natural-gradient-boosting surrogate model."""

# Class variables
joint_posterior: ClassVar[bool] = False
# See base class.

supports_transfer_learning: ClassVar[bool] = False
# See base class.

_default_model_params: ClassVar[dict] = {"n_estimators": 25, "verbose": False}
"""Class variable encoding the default model parameters."""

# Object variables
model_params: dict[str, Any] = field(
factory=dict,
converter=dict,
Expand Down Expand Up @@ -70,7 +58,7 @@ def _make_target_scaler_factory() -> type[OutcomeTransform] | None:
# Tree-like models do not require any output scaling
return None

@batchify
@batchify_mean_var_prediction
def _estimate_moments(
self, candidates_comp_scaled: Tensor, /
) -> tuple[Tensor, Tensor]:
Expand Down
Loading

0 comments on commit 1e387b0

Please sign in to comment.