Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MIN mode via acquisition function #340

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
in the recommender but in the surrogate
- Fallback models created by `catch_constant_targets` are stored outside the surrogate
- `to_tensor` now also handles `numpy` arrays
- `MIN` mode of `NumericalTarget` is now implemented via the acquisition function
instead of negating the computational representation

### Fixed
- `CategoricalParameter` and `TaskParameter` no longer incorrectly coerce a single
Expand Down
35 changes: 30 additions & 5 deletions baybe/acquisition/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from attrs import define

from baybe.objectives.base import Objective
from baybe.objectives.desirability import DesirabilityObjective
from baybe.objectives.single import SingleTargetObjective
from baybe.searchspace.core import SearchSpace
from baybe.serialization.core import (
converter,
Expand All @@ -19,6 +21,8 @@
)
from baybe.serialization.mixin import SerialMixin
from baybe.surrogates.base import SurrogateProtocol
from baybe.targets.enum import TargetMode
from baybe.targets.numerical import NumericalTarget
from baybe.utils.basic import classproperty, match_attributes
from baybe.utils.boolean import is_abstract
from baybe.utils.dataframe import to_tensor
Expand Down Expand Up @@ -53,14 +57,16 @@ def to_botorch(
The required structure of `measurements` is specified in
:meth:`baybe.recommenders.base.RecommenderProtocol.recommend`.
"""
import botorch.acquisition as botorch_acqf_module
import botorch.acquisition as bo_acqf
import torch
from botorch.acquisition.objective import LinearMCObjective

# Get computational data representations
train_x = searchspace.transform(measurements, allow_extra=True)
train_y = objective.transform(measurements)

# Retrieve corresponding botorch class
acqf_cls = getattr(botorch_acqf_module, self.__class__.__name__)
acqf_cls = getattr(bo_acqf, self.__class__.__name__)

# Match relevant attributes
params_dict = match_attributes(
Expand All @@ -72,17 +78,36 @@ def to_botorch(
additional_params = {}
if "model" in signature_params:
additional_params["model"] = surrogate.to_botorch()
if "best_f" in signature_params:
additional_params["best_f"] = train_y.max().item()
if "X_baseline" in signature_params:
additional_params["X_baseline"] = to_tensor(train_x)
if "mc_points" in signature_params:
additional_params["mc_points"] = to_tensor(
self.get_integration_points(searchspace) # type: ignore[attr-defined]
)

params_dict.update(additional_params)
# Add acquisition objective / best observed value
match objective:
AVHopp marked this conversation as resolved.
Show resolved Hide resolved
case SingleTargetObjective(NumericalTarget(mode=TargetMode.MIN)):
if "best_f" in signature_params:
AVHopp marked this conversation as resolved.
Show resolved Hide resolved
additional_params["best_f"] = train_y.min().item()

if issubclass(acqf_cls, bo_acqf.AnalyticAcquisitionFunction):
additional_params["maximize"] = False
AdrianSosic marked this conversation as resolved.
Show resolved Hide resolved
elif issubclass(acqf_cls, bo_acqf.MCAcquisitionFunction):
additional_params["objective"] = LinearMCObjective(
torch.tensor([-1.0])
)
else:
raise ValueError(
f"Unsupported acquisition function type: {acqf_cls}."
)
case SingleTargetObjective() | DesirabilityObjective():
AVHopp marked this conversation as resolved.
Show resolved Hide resolved
if "best_f" in signature_params:
additional_params["best_f"] = train_y.max().item()
case _:
raise ValueError(f"Unsupported objective type: {objective}")

params_dict.update(additional_params)
return acqf_cls(**params_dict)


Expand Down
7 changes: 0 additions & 7 deletions baybe/targets/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,6 @@ def transform(self, data: pd.DataFrame) -> pd.DataFrame: # noqa: D102
transformed = pd.DataFrame(
func(data, *self.bounds.to_tuple()), index=data.index
)

# Otherwise, simply negate all target values for ``MIN`` mode.
# For ``MAX`` mode, nothing needs to be done.
# For ``MATCH`` mode, the validators avoid a situation without specified bounds.
elif self.mode is TargetMode.MIN:
transformed = -data

else:
transformed = data.copy()

Expand Down
12 changes: 11 additions & 1 deletion streamlit/surrogate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,26 @@ def main():
}

# Streamlit simulation parameters
st.sidebar.markdown("# Domain")
st_random_seed = int(st.sidebar.number_input("Random seed", value=1337))
st_function_name = st.sidebar.selectbox(
"Test function", list(test_functions.keys())
)
st_target_mode = st.sidebar.radio(
"Objective",
["MAX", "MIN"],
format_func=lambda x: {"MAX": "Maximization", "MIN": "Minimization"}[x],
horizontal=True,
)
st.sidebar.markdown("---")
st.sidebar.markdown("# Model")
st_surrogate_name = st.sidebar.selectbox(
"Surrogate model", list(surrogate_model_classes.keys())
)
st_n_training_points = st.sidebar.slider("Number of training points", 1, 20, 5)
st_n_recommendations = st.sidebar.slider("Number of recommendations", 1, 20, 5)
st.sidebar.markdown("---")
st.sidebar.markdown("# Validation")
st.sidebar.markdown(
"""
When scaling is implemented correctly, the plot should remain static (except for
Expand Down Expand Up @@ -139,7 +149,7 @@ def main():
),
)
searchspace = SearchSpace.from_product(parameters=[parameter])
objective = NumericalTarget(name="y", mode="MAX").to_objective()
objective = NumericalTarget(name="y", mode=st_target_mode).to_objective()

# Create the surrogate model and the recommender
surrogate_model = surrogate_model_classes[st_surrogate_name]()
Expand Down