Merge: Remove get_surrogate Restrictions (#386)
Fixes #385 
- enables the extraction of surrogates for transformed single targets or
desirability objectives
- adds a few tests
- mentions `get_surrogate` in the user guide
Scienfitz authored Sep 30, 2024
2 parents 02ad6e2 + cfde6bb commit f069cdd
Showing 4 changed files with 88 additions and 19 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Changed
- `get_surrogate` now also returns the model for transformed single targets or
desirability objectives

### Fixed
- Unsafe name-based matching of columns in `get_comp_rep_parameter_indices`

21 changes: 7 additions & 14 deletions baybe/campaign.py
@@ -14,7 +14,6 @@

from baybe.exceptions import IncompatibilityError
from baybe.objectives.base import Objective, to_objective
from baybe.objectives.single import SingleTargetObjective
from baybe.parameters.base import Parameter
from baybe.recommenders.base import RecommenderProtocol
from baybe.recommenders.meta.base import MetaRecommender
@@ -29,7 +28,6 @@
from baybe.serialization import SerialMixin, converter
from baybe.surrogates.base import SurrogateProtocol
from baybe.targets.base import Target
from baybe.targets.numerical import NumericalTarget
from baybe.telemetry import (
    TELEM_LABELS,
    telemetry_record_recommended_measurement_percentage,
@@ -317,19 +315,14 @@ def get_surrogate(self) -> SurrogateProtocol:
        Returns:
            Surrogate: The surrogate of the current recommender.
        """
        # TODO: remove temporary restriction when target transformations can be handled
        match self.objective:
            case SingleTargetObjective(
                _target=NumericalTarget(bounds=b)
            ) if not b.is_bounded:
                pass
            case _:
                raise NotImplementedError(
                    "Surrogate model access is currently only supported for a single "
                    "untransformed target."
                )
        Note:
            Currently, this method always returns the surrogate model with respect to
            the transformed target(s) / objective. This means that if you are using a
            ``SingleTargetObjective`` with a transformed target or a
            ``DesirabilityObjective``, the model's output will correspond to the
            transformed quantities and not the original untransformed target(s).
        """
        if self.objective is None:
            raise IncompatibilityError(
                f"No surrogate is available since no '{Objective.__name__}' is defined."
51 changes: 46 additions & 5 deletions docs/userguide/surrogates.md
@@ -1,18 +1,59 @@
# Surrogates

Surrogate models are used to model and estimate the unknown objective function of the DoE campaign. BayBE offers a diverse array of surrogate models, while also allowing for the utilization of custom models. All surrogate models are based upon the general [`Surrogate`](baybe.surrogates.base.Surrogate) class. Some models even support transfer learning, as indicated by the `supports_transfer_learning` attribute.
Surrogate models are used to model and estimate the unknown objective function of the
DoE campaign. BayBE offers a diverse array of surrogate models, while also allowing for
the utilization of custom models. All surrogate models are based upon the general
[`Surrogate`](baybe.surrogates.base.Surrogate) class. Some models even support transfer
learning, as indicated by the `supports_transfer_learning` attribute.

## Available models
## Available Models

BayBE provides a comprehensive selection of surrogate models, empowering you to choose the most suitable option for your specific needs. The following surrogate models are available within BayBE:
BayBE provides a comprehensive selection of surrogate models, empowering you to choose
the most suitable option for your specific needs. The following surrogate models are
available within BayBE:

* [`GaussianProcessSurrogate`](baybe.surrogates.gaussian_process.core.GaussianProcessSurrogate)
* [`BayesianLinearSurrogate`](baybe.surrogates.linear.BayesianLinearSurrogate)
* [`MeanPredictionSurrogate`](baybe.surrogates.naive.MeanPredictionSurrogate)
* [`NGBoostSurrogate`](baybe.surrogates.ngboost.NGBoostSurrogate)
* [`RandomForestSurrogate`](baybe.surrogates.random_forest.RandomForestSurrogate)

## Extracting the Model for Advanced Study

## Using custom models
In principle, the surrogate model does not need to be a persistent object during
Bayesian optimization since each iteration performs a new fit anyway. However, for
advanced study, such as investigating the posterior predictions, acquisition functions
or feature importance, it can be useful to directly extract the current surrogate model.

BayBE goes one step further by allowing you to incorporate custom models based on the ONNX architecture. Note however that these cannot be retrained. For a detailed explanation on using custom models, refer to the comprehensive examples provided in the corresponding [example folder](./../../examples/Custom_Surrogates/Custom_Surrogates).
For this, BayBE provides the ``get_surrogate`` method, which is available for the
[``Campaign``](baybe.campaign.Campaign.get_surrogate) or for
[recommenders](baybe.recommenders.pure.bayesian.base.BayesianRecommender.get_surrogate).
Below is an example of how to use this in conjunction with the popular SHAP package:

~~~python
import shap

# Assuming we already have a campaign created and measurements added
data = campaign.measurements[[p.name for p in campaign.parameters]]
model = lambda x: campaign.get_surrogate().posterior(x).mean

# Apply SHAP
explainer = shap.Explainer(model, data)
shap_values = explainer(data)
shap.plots.bar(shap_values)
~~~

```{admonition} Current Scalarization Limitations
:class: note
Currently, ``get_surrogate`` always returns the surrogate model with respect to the
transformed target(s) / objective. This means that if you are using a
``SingleTargetObjective`` with a transformed target or a ``DesirabilityObjective``, the
model's output will correspond to the transformed quantities and not the original
untransformed target(s). Keep this in mind when using the model for subsequent
analysis.
```
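
To make the difference between transformed and original quantities concrete, here is a minimal, self-contained sketch. It does not use BayBE: the helpers `ramp_max`, `ramp_min` and `geometric_mean`, as well as the example targets and bounds, are hypothetical stand-ins that assume a clamped linear transformation for bounded targets and a geometric-mean scalarization. The actual transformations applied depend on the target and objective configuration.

```python
import math


def ramp_max(y: float, lower: float, upper: float) -> float:
    """Clamped linear ramp for a 'maximize' target: lower maps to 0, upper to 1."""
    return min(1.0, max(0.0, (y - lower) / (upper - lower)))


def ramp_min(y: float, lower: float, upper: float) -> float:
    """Clamped linear ramp for a 'minimize' target: lower maps to 1, upper to 0."""
    return 1.0 - ramp_max(y, lower, upper)


def geometric_mean(values: list[float]) -> float:
    """Scalarize several transformed targets into one desirability value."""
    return math.prod(values) ** (1.0 / len(values))


# Hypothetical raw measurements in their original units
raw_yield = 80.0  # percent, to be maximized within [0, 100]
raw_cost = 30.0  # currency units, to be minimized within [0, 50]

d_yield = ramp_max(raw_yield, 0.0, 100.0)
d_cost = ramp_min(raw_cost, 0.0, 50.0)
desirability = geometric_mean([d_yield, d_cost])

# The scalarized value lives in [0, 1] and carries no physical unit.
print(f"yield -> {d_yield:.3f}, cost -> {d_cost:.3f}, desirability -> {desirability:.3f}")
```

A surrogate fitted on such quantities predicts values on the unitless scale, so its outputs (and anything derived from them, such as SHAP values) should not be read as yields or costs.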

## Using Custom Models

BayBE goes one step further by allowing you to incorporate custom models based on the
ONNX architecture. Note however that these cannot be retrained. For a detailed
explanation on using custom models, refer to the comprehensive examples provided in the
corresponding [example folder](./../../examples/Custom_Surrogates/Custom_Surrogates).
31 changes: 31 additions & 0 deletions tests/test_campaign.py
@@ -0,0 +1,31 @@
"""Tests features of the Campaign object."""

import pytest
from pytest import param

from .conftest import run_iterations


@pytest.mark.parametrize(
    "target_names",
    [
        param(["Target_max"], id="max"),
        param(["Target_min"], id="min"),
        param(["Target_max_bounded"], id="max_b"),
        param(["Target_min_bounded"], id="min_b"),
        param(["Target_match_bell"], id="match_bell"),
        param(["Target_match_triangular"], id="match_tri"),
        param(
            ["Target_max_bounded", "Target_min_bounded", "Target_match_triangular"],
            id="desirability",
        ),
    ],
)
@pytest.mark.parametrize("batch_size", [2], ids=["b2"])
@pytest.mark.parametrize("n_iterations", [2], ids=["i2"])
def test_get_surrogate(campaign, n_iterations, batch_size):
    """Test successful extraction of the surrogate model."""
    run_iterations(campaign, n_iterations, batch_size)

    model = campaign.get_surrogate()
    assert model is not None, "Something went wrong during surrogate model extraction."
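
Stacked `pytest.mark.parametrize` decorators are combined as a Cartesian product, so the test above runs once per target configuration. The sketch below only counts those combinations with `itertools.product`. It is an illustration using the ids from the diff and does not reproduce pytest's collection logic or its exact test-id format.

```python
from itertools import product

# Ids copied from the parametrization in the diff above
target_ids = ["max", "min", "max_b", "min_b", "match_bell", "match_tri", "desirability"]
batch_ids = ["b2"]
iteration_ids = ["i2"]

# One test case per combination of the three parametrized arguments
cases = list(product(target_ids, batch_ids, iteration_ids))

for case in cases:
    print(case)
print(f"{len(cases)} test cases in total")
```

Six of these cases cover single targets and one covers the multi-target desirability setup.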
