Skip to content

Commit

Permalink
Avoid batch recommendation with marginal surrogate models
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Sep 2, 2024
1 parent 25ccfb8 commit f3c2ac7
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
6 changes: 3 additions & 3 deletions examples/Custom_Surrogates/custom_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
### Iterate with recommendations and measurements

# Let's do a first round of recommendation
recommendation = campaign.recommend(batch_size=2)
recommendation = campaign.recommend(batch_size=1)

print("Recommendation from campaign:")
print(recommendation)
Expand All @@ -122,10 +122,10 @@

### Model Outputs

# Do another round of recommendations
# Do another round of recommendation
recommendation = campaign.recommend(batch_size=2)

# Print second round of recommendations
# Print second round of recommendation

print("Recommendation from campaign:")
print(recommendation)
Expand Down
8 changes: 4 additions & 4 deletions examples/Custom_Surrogates/surrogate_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
print(surrogate_model.to_json(), end="\n" * 3)

# Let's do a first round of recommendation
recommendation = campaign.recommend(batch_size=2)
recommendation = campaign.recommend(batch_size=1)

print("Recommendation from campaign:")
print(recommendation)
Expand All @@ -112,11 +112,11 @@

print("Here you will see some model outputs as we set verbose to True")

# Do another round of recommendations
recommendation = campaign.recommend(batch_size=2)
# Do another round of recommendation
recommendation = campaign.recommend(batch_size=1)


# Print second round of recommendations
# Print second round of recommendation

print("Recommendation from campaign:")
print(recommendation)
Expand Down
6 changes: 4 additions & 2 deletions tests/test_iterations.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
)
from baybe.recommenders.pure.nonpredictive.base import NonPredictiveRecommender
from baybe.searchspace import SearchSpaceType
from baybe.surrogates.base import Surrogate
from baybe.surrogates.base import IndependentGaussianSurrogate, Surrogate
from baybe.surrogates.custom import CustomONNXSurrogate
from baybe.surrogates.gaussian_process.presets import (
DefaultKernelFactory,
Expand Down Expand Up @@ -249,7 +249,9 @@ def test_kernel_factories(campaign, n_iterations, batch_size):
valid_surrogate_models,
ids=[c.__class__ for c in valid_surrogate_models],
)
def test_surrogate_models(campaign, n_iterations, batch_size):
def test_surrogate_models(campaign, n_iterations, batch_size, surrogate_model):
if batch_size > 1 and isinstance(surrogate_model, IndependentGaussianSurrogate):
pytest.skip("Batch recommendation is not supported.")
run_iterations(campaign, n_iterations, batch_size)


Expand Down

0 comments on commit f3c2ac7

Please sign in to comment.