Skip to content

Commit

Permalink
Updated optimize_objective (input_constructors.py) to allow specifi…
Browse files Browse the repository at this point in the history
…cation of acquisition function as kwarg (#2516)

Summary:
Pull Request resolved: #2516

Support an optional acquisition function instance as a keyword argument, useful for multi-fidelity acquisition functions.

Reviewed By: SebastianAment

Differential Revision: D62380369

fbshipit-source-id: 0131e677415b487de6adfd6d5c828d4904be4dc2
  • Loading branch information
ltiao authored and facebook-github-bot committed Sep 9, 2024
1 parent 571fe66 commit 4dc1271
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
28 changes: 16 additions & 12 deletions botorch/acquisition/input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,6 +1532,7 @@ def optimize_objective(
model: Model,
bounds: Tensor,
q: int,
acq_function: Optional[AcquisitionFunction] = None,
objective: Optional[MCAcquisitionObjective] = None,
posterior_transform: Optional[PosteriorTransform] = None,
linear_constraints: Optional[tuple[Tensor, Tensor]] = None,
Expand Down Expand Up @@ -1574,18 +1575,21 @@ def optimize_objective(
if optimizer_options is None:
optimizer_options = {}

if objective is not None:
sampler_cls = SobolQMCNormalSampler if qmc else IIDNormalSampler
acq_function = qSimpleRegret(
model=model,
objective=objective,
posterior_transform=posterior_transform,
sampler=sampler_cls(sample_shape=torch.Size([mc_samples]), seed=seed_inner),
)
else:
acq_function = PosteriorMean(
model=model, posterior_transform=posterior_transform
)
if acq_function is None:
if objective is None:
acq_function = PosteriorMean(
model=model, posterior_transform=posterior_transform
)
else:
sampler_cls = SobolQMCNormalSampler if qmc else IIDNormalSampler
acq_function = qSimpleRegret(
model=model,
objective=objective,
posterior_transform=posterior_transform,
sampler=sampler_cls(
sample_shape=torch.Size([mc_samples]), seed=seed_inner
),
)

if fixed_features:
acq_function = FixedFeatureAcquisitionFunction(
Expand Down
10 changes: 10 additions & 0 deletions test/acquisition/test_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,16 @@ def test_optimize_objective(self, mock_optimize_acqf):
mock_model = self.mock_model
bounds = torch.rand(2, len(self.bounds))

with self.subTest("scalarObjective_acqusitionFunction"):
optimize_objective(
model=mock_model,
bounds=bounds,
q=1,
acq_function=UpperConfidenceBound(model=mock_model, beta=0.1),
)
kwargs = mock_optimize_acqf.call_args[1]
self.assertIsInstance(kwargs["acq_function"], UpperConfidenceBound)

A = torch.rand(1, bounds.shape[-1])
b = torch.zeros([1, 1])
idx = A[0].nonzero(as_tuple=False).squeeze()
Expand Down

0 comments on commit 4dc1271

Please sign in to comment.