Skip to content

Commit

Permalink
Adding qLog(N)EI to get_acquisition_function (#1941)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1941

This is used by the legacy model setup on the Ax side.

Reviewed By: esantorella

Differential Revision: D47633151

fbshipit-source-id: 3a9726dbb362b8bb69c24f27950f11b8561e1ea9
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Jul 25, 2023
1 parent 71fd34e commit 7e37fb7
Show file tree
Hide file tree
Showing 2 changed files with 254 additions and 66 deletions.
34 changes: 33 additions & 1 deletion botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_acquisition_function(
"acquisition functions."
)
# instantiate and return the requested acquisition function
if acquisition_function_name in ("qEI", "qPI"):
if acquisition_function_name in ("qEI", "qLogEI", "qPI"):
# Since these are the non-noisy variants, use the posterior mean at the observed
# inputs directly to compute the best feasible value without sampling.
Y = model.posterior(X_observed, posterior_transform=posterior_transform).mean
Expand All @@ -125,6 +125,22 @@ def get_acquisition_function(
constraints=constraints,
eta=eta,
)
if acquisition_function_name == "qLogEI":
# putting the import here to avoid circular imports
# ideally, the entire function should be moved out of this file,
# but since it is used for legacy code to be deprecated, we keep it here.
from botorch.acquisition.logei import qLogExpectedImprovement

return qLogExpectedImprovement(
model=model,
best_f=best_f,
sampler=sampler,
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
constraints=constraints,
eta=eta,
)
elif acquisition_function_name == "qPI":
return monte_carlo.qProbabilityOfImprovement(
model=model,
Expand All @@ -151,6 +167,22 @@ def get_acquisition_function(
constraints=constraints,
eta=eta,
)
elif acquisition_function_name == "qLogNEI":
from botorch.acquisition.logei import qLogNoisyExpectedImprovement

return qLogNoisyExpectedImprovement(
model=model,
X_baseline=X_observed,
sampler=sampler,
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
prune_baseline=kwargs.get("prune_baseline", True),
marginalize_dim=kwargs.get("marginalize_dim"),
cache_root=kwargs.get("cache_root", True),
constraints=constraints,
eta=eta,
)
elif acquisition_function_name == "qSR":
return monte_carlo.qSimpleRegret(
model=model,
Expand Down
Loading

0 comments on commit 7e37fb7

Please sign in to comment.