-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge: Remove
get_surrogate
Restrictions (#386)
Fixes #385 - enables the extraction of surrogates for transformed single targets or desirability objects - adds a few tests - mentions `get_surrogate` in the user guide
- Loading branch information
Showing
4 changed files
with
88 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,59 @@ | ||
# Surrogates | ||
|
||
Surrogate models are used to model and estimate the unknown objective function of the DoE campaign. BayBE offers a diverse array of surrogate models, while also allowing for the utilization of custom models. All surrogate models are based upon the general [`Surrogate`](baybe.surrogates.base.Surrogate) class. Some models even support transfer learning, as indicated by the `supports_transfer_learning` attribute. | ||
Surrogate models are used to model and estimate the unknown objective function of the | ||
DoE campaign. BayBE offers a diverse array of surrogate models, while also allowing for | ||
the utilization of custom models. All surrogate models are based upon the general | ||
[`Surrogate`](baybe.surrogates.base.Surrogate) class. Some models even support transfer | ||
learning, as indicated by the `supports_transfer_learning` attribute. | ||
|
||
## Available models | ||
## Available Models | ||
|
||
BayBE provides a comprehensive selection of surrogate models, empowering you to choose the most suitable option for your specific needs. The following surrogate models are available within BayBE: | ||
BayBE provides a comprehensive selection of surrogate models, empowering you to choose | ||
the most suitable option for your specific needs. The following surrogate models are | ||
available within BayBE: | ||
|
||
* [`GaussianProcessSurrogate`](baybe.surrogates.gaussian_process.core.GaussianProcessSurrogate) | ||
* [`BayesianLinearSurrogate`](baybe.surrogates.linear.BayesianLinearSurrogate) | ||
* [`MeanPredictionSurrogate`](baybe.surrogates.naive.MeanPredictionSurrogate) | ||
* [`NGBoostSurrogate`](baybe.surrogates.ngboost.NGBoostSurrogate) | ||
* [`RandomForestSurrogate`](baybe.surrogates.random_forest.RandomForestSurrogate) | ||
|
||
## Extracting the Model for Advanced Study | ||
|
||
## Using custom models | ||
In principle, the surrogate model does not need to be a persistent object during | ||
Bayesian optimization since each iteration performs a new fit anyway. However, for | ||
advanced study, such as investigating the posterior predictions, acquisition functions | ||
or feature importance, it can be useful to diretly extract the current surrogate model. | ||
|
||
BayBE goes one step further by allowing you to incorporate custom models based on the ONNX architecture. Note however that these cannot be retrained. For a detailed explanation on using custom models, refer to the comprehensive examples provided in the corresponding [example folder](./../../examples/Custom_Surrogates/Custom_Surrogates). | ||
For this, BayBE provides the ``get_surrogate`` method, which is available for the | ||
[``Campaign``](baybe.campaign.Campaign.get_surrogate) or for | ||
[recommenders](baybe.recommenders.pure.bayesian.base.BayesianRecommender.get_surrogate). | ||
Below an example of how to utilize this in conjunction with the popular SHAP package: | ||
|
||
~~~python | ||
# Assuming we already have a campaign created and measurements added | ||
data = campaign.measurements[[p.name for p in campaign.parameters]] | ||
model = lambda x: campaign.get_surrogate().posterior(x).mean | ||
|
||
# Apply SHAP | ||
explainer = shap.Explainer(model, data) | ||
shap_values = explainer(data) | ||
shap.plots.bar(shap_values) | ||
~~~ | ||
|
||
```{admonition} Current Scalarization Limitations | ||
:class: note | ||
Currently, ``get_surrogate`` always returns the surrogate model with respect to the | ||
transformed target(s) / objective. This means that if you are using a | ||
``SingleTargetObjective`` with a transformed target or a ``DesirabilityObjective``, the | ||
model's output will correspond to the transformed quantities and not the original | ||
untransformed target(s). If you are using the model for subsequent analysis this should | ||
be kept in mind. | ||
``` | ||
|
||
## Using Custom Models | ||
|
||
BayBE goes one step further by allowing you to incorporate custom models based on the | ||
ONNX architecture. Note however that these cannot be retrained. For a detailed | ||
explanation on using custom models, refer to the comprehensive examples provided in the | ||
corresponding [example folder](./../../examples/Custom_Surrogates/Custom_Surrogates). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
"""Tests features of the Campaign object.""" | ||
|
||
import pytest | ||
from pytest import param | ||
|
||
from .conftest import run_iterations | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"target_names", | ||
[ | ||
param(["Target_max"], id="max"), | ||
param(["Target_min"], id="min"), | ||
param(["Target_max_bounded"], id="max_b"), | ||
param(["Target_min_bounded"], id="min_b"), | ||
param(["Target_match_bell"], id="match_bell"), | ||
param(["Target_match_triangular"], id="match_tri"), | ||
param( | ||
["Target_max_bounded", "Target_min_bounded", "Target_match_triangular"], | ||
id="desirability", | ||
), | ||
], | ||
) | ||
@pytest.mark.parametrize("batch_size", [2], ids=["b2"]) | ||
@pytest.mark.parametrize("n_iterations", [2], ids=["i2"]) | ||
def test_get_surrogate(campaign, n_iterations, batch_size): | ||
"""Test successful extraction of the surrogate model.""" | ||
run_iterations(campaign, n_iterations, batch_size) | ||
|
||
model = campaign.get_surrogate() | ||
assert model is not None, "Something went wrong during surrogate model extraction." |