Re-evaluate acqf after post-processing in optimize_acqf (pytorch#1840)
Summary:
Pull Request resolved: pytorch#1840

Resolves pytorch#1736

Reviewed By: Balandat

Differential Revision: D44278568

fbshipit-source-id: a7f45676765db178f346ab836b4757a695ce25da
saitcakmak authored and facebook-github-bot committed May 18, 2023
1 parent 8c9d54b commit 70d0c63
Showing 2 changed files with 18 additions and 10 deletions.
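
For context, a minimal usage sketch of the behavior this commit changes: when a post_processing_func (for example, a rounding function) modifies the candidates returned by the optimizer, the reported acquisition value should be re-evaluated on the post-processed candidates instead of being carried over from the pre-processing solutions (pytorch#1736). The surrogate model, acquisition function, toy data, and rounding function below are illustrative assumptions and are not part of the commit.

import torch
from botorch.acquisition import qExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy problem: 5 observations in the 3-dimensional unit cube.
train_X = torch.rand(5, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

acqf = qExpectedImprovement(model=model, best_f=train_Y.max())
bounds = torch.stack([torch.zeros(3), torch.ones(3)]).to(train_X)

def rounding_func(X):
    # Illustrative post-processing: round candidates to one decimal place.
    return X.round(decimals=1)

candidates, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=2,
    num_restarts=2,
    raw_samples=8,
    post_processing_func=rounding_func,
    sequential=True,
)
# With this change, acq_value is computed on the rounded `candidates`;
# previously it reflected the unrounded solutions found by the optimizer.
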
botorch/optim/optimize.py: 6 additions & 0 deletions
@@ -398,6 +398,12 @@ def _optimize_batch_candidates(
 
     if opt_inputs.post_processing_func is not None:
         batch_candidates = opt_inputs.post_processing_func(batch_candidates)
+        with torch.no_grad():
+            acq_values_list = [
+                opt_inputs.acq_function(cand)
+                for cand in batch_candidates.split(batch_limit, dim=0)
+            ]
+            batch_acq_values = torch.cat(acq_values_list, dim=0)
 
     if opt_inputs.return_best_only:
         best = torch.argmax(batch_acq_values.view(-1), dim=0)
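
The re-evaluation added above follows a simple pattern: split the (possibly large) candidate tensor into chunks of at most batch_limit along the batch dimension and evaluate each chunk under torch.no_grad(), since only the values are needed here and tracking gradients would waste memory. A standalone sketch of the same pattern; the helper name and the toy acquisition function are illustrative, not part of the commit.

import torch

def reevaluate_in_batches(acq_function, batch_candidates, batch_limit):
    # Evaluate candidates chunk by chunk without building an autograd graph,
    # then stitch the per-chunk acquisition values back together.
    with torch.no_grad():
        values = [
            acq_function(cand)
            for cand in batch_candidates.split(batch_limit, dim=0)
        ]
    return torch.cat(values, dim=0)

def toy_acqf(X):
    # One scalar per q-batch: negative max distance to the origin
    # (input shape batch x q x d, output shape batch).
    return -X.norm(dim=-1).amax(dim=-1)

vals = reevaluate_in_batches(toy_acqf, torch.rand(10, 1, 3), batch_limit=4)
print(vals.shape)  # torch.Size([10])
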
test/optim/test_optimize.py: 12 additions & 10 deletions
@@ -269,7 +269,7 @@ def test_optimize_acqf_sequential(
         num_restarts = 2
         raw_samples = 10
         options = {}
-        for dtype in (torch.float, torch.double):
+        for dtype, use_rounding in ((torch.float, True), (torch.double, False)):
             mock_acq_function = MockAcquisitionFunction()
             mock_gen_batch_initial_conditions.side_effect = [
                 torch.zeros(num_restarts, 1, 3, device=self.device, dtype=dtype)
@@ -285,9 +285,6 @@
                 for i in range(q)
             ]
             mock_gen_candidates.side_effect = gcs_return_vals
-            expected_candidates = torch.cat(
-                [cands[0] for cands, _ in gcs_return_vals], dim=-2
-            ).round()
             bounds = torch.stack(
                 [
                     torch.zeros(3, device=self.device, dtype=dtype),
@@ -306,18 +303,23 @@
                 raw_samples=raw_samples,
                 options=options,
                 inequality_constraints=inequality_constraints,
-                post_processing_func=rounding_func,
+                post_processing_func=rounding_func if use_rounding else None,
                 sequential=True,
                 timeout_sec=timeout_sec,
                 gen_candidates=mock_gen_candidates,
             )
             self.assertEqual(mock_gen_candidates.call_count, q)
+            base_candidates = torch.cat(
+                [cands[0] for cands, _ in gcs_return_vals], dim=-2
+            )
+            if use_rounding:
+                expected_candidates = base_candidates.round()
+                expected_val = mock_acq_function(expected_candidates.unsqueeze(-2))
+            else:
+                expected_candidates = base_candidates
+                expected_val = torch.cat([acqval for _, acqval in gcs_return_vals])
             self.assertTrue(torch.equal(candidates, expected_candidates))
-            self.assertTrue(
-                torch.equal(
-                    acq_value, torch.cat([acqval for _, acqval in gcs_return_vals])
-                )
-            )
+            self.assertTrue(torch.equal(acq_value, expected_val))
         # verify error when using a OneShotAcquisitionFunction
         with self.assertRaises(NotImplementedError):
             optimize_acqf(
