make qNIPV not an AnalyticAcquisitionFunction; optimize_acqf support clarity (pytorch#2286)

Summary:

I put a `LogExpectedImprovement` instance into `optimize_acqf`, and when I got an error about it not having an attribute `X_pending`, I was not sure whether this was a bug or whether I had done something known to be unsupported.
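For context, here is a minimal sketch of the kind of call that can surface this confusion. The model setup, `q=2`, and `sequential=True` are assumptions made for illustration (the commit does not include the original reproduction); `SingleTaskGP`, `LogExpectedImprovement`, and `optimize_acqf` are the actual BoTorch entry points involved:

```python
# Hypothetical reproduction sketch; the exact arguments are assumptions.
import torch
from botorch.acquisition.analytic import LogExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

acqf = LogExpectedImprovement(model=model, best_f=train_Y.max())
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)

# Analytic acquisition functions only support q=1 and do not carry an
# `X_pending` attribute; asking for sequential batch optimization makes
# `optimize_acqf` touch `X_pending`, which can surface as an AttributeError
# rather than a clear "this is unsupported" message.
candidates, value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=2,
    num_restarts=4,
    raw_samples=16,
    sequential=True,
)
```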

- Make `qNegIntegratedPosteriorVariance` inherit from `AcquisitionFunction` rather than `AnalyticAcquisitionFunction`, because the functionality it was inheriting from `AnalyticAcquisitionFunction` was not relevant.
- `qNegIntegratedPosteriorVariance` loses its error message claiming that multi-output models are not supported without a scalarizing `PosteriorTransform`, and gains a unit test showing that this case does work; a brief usage sketch follows below.
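A hedged usage sketch of that newly tested situation (a multi-output model with `qNegIntegratedPosteriorVariance` and no `PosteriorTransform`); the two-output `SingleTaskGP`, the tensor shapes, and the number of integration points are illustrative assumptions, not taken from the commit:

```python
# Illustrative sketch, not from the commit: a two-output SingleTaskGP with qNIPV
# and no PosteriorTransform, the case the removed error message used to reject.
import torch
from botorch.acquisition.active_learning import qNegIntegratedPosteriorVariance
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.models import SingleTaskGP

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.rand(10, 2, dtype=torch.double)  # two outputs
model = SingleTaskGP(train_X, train_Y)

mc_points = torch.rand(128, 2, dtype=torch.double)  # integration points
qnipv = qNegIntegratedPosteriorVariance(model=model, mc_points=mc_points)

X = torch.rand(3, 1, 2, dtype=torch.double)  # 3 batches of q=1 candidate points
vals = qnipv(X)  # negative integrated posterior variance values

# Scalarizing the outputs remains available via ScalarizedPosteriorTransform,
# mirroring the updated unit test:
weights = torch.tensor([0.5, 0.5], dtype=torch.double)
qnipv_scalarized = qNegIntegratedPosteriorVariance(
    model=model,
    mc_points=mc_points,
    posterior_transform=ScalarizedPosteriorTransform(weights=weights),
)
```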

Differential Revision: D55843171
esantorella authored and facebook-github-bot committed Apr 6, 2024
1 parent 968e465 commit a8f6565
Showing 3 changed files with 30 additions and 21 deletions.
7 changes: 4 additions & 3 deletions botorch/acquisition/active_learning.py
@@ -27,7 +27,7 @@
 
 import torch
 from botorch import settings
-from botorch.acquisition.analytic import AnalyticAcquisitionFunction
+from botorch.acquisition.acquisition import AcquisitionFunction
 from botorch.acquisition.monte_carlo import MCAcquisitionFunction
 from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
 from botorch.models.model import Model
@@ -37,7 +37,7 @@
 from torch import Tensor
 
 
-class qNegIntegratedPosteriorVariance(AnalyticAcquisitionFunction):
+class qNegIntegratedPosteriorVariance(AcquisitionFunction):
     r"""Batch Integrated Negative Posterior Variance for Active Learning.
     This acquisition function quantifies the (negative) integrated posterior variance
@@ -75,7 +75,8 @@ def __init__(
                 points that have been submitted for function evaluation but
                 have not yet been evaluated.
         """
-        super().__init__(model=model, posterior_transform=posterior_transform)
+        super().__init__(model=model)
+        self.posterior_transform = posterior_transform
         if sampler is None:
             # If no sampler is provided, we use the following dummy sampler for the
             # fantasize() method in forward. IMPORTANT: This assumes that the posterior
4 changes: 3 additions & 1 deletion botorch/optim/optimize.py
@@ -153,6 +153,7 @@ def _raise_deprecation_warning_if_kwargs(fn_name: str, kwargs: Dict[str, Any]) -
             f"`{fn_name}` does not support arguments {list(kwargs.keys())}. In "
             "the future, this will become an error.",
             DeprecationWarning,
+            stacklevel=2,
         )


@@ -366,7 +367,7 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
             f"warning(s):\n{[w.message for w in ws]}\nTrying again with a new "
             "set of initial conditions."
         )
-        warnings.warn(first_warn_msg, RuntimeWarning)
+        warnings.warn(first_warn_msg, RuntimeWarning, stacklevel=2)
 
         if not initial_conditions_provided:
             batch_initial_conditions = opt_inputs.get_ic_generator()(
@@ -392,6 +393,7 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
                 "Optimization failed on the second try, after generating a "
                 "new set of initial conditions.",
                 RuntimeWarning,
+                stacklevel=2,
             )
 
     if opt_inputs.post_processing_func is not None:
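The `stacklevel=2` additions make Python attribute each warning to the code that called the BoTorch helper rather than to the helper itself, which is usually the more actionable location. A minimal, self-contained illustration of the mechanism (plain Python, not BoTorch code):

```python
import warnings

def deprecated_helper():
    # With stacklevel=2, the reported filename and line number point at the
    # caller of deprecated_helper() instead of at this warn() call.
    warnings.warn("this helper is deprecated", DeprecationWarning, stacklevel=2)

def user_code():
    deprecated_helper()  # the warning is attributed to this line

warnings.simplefilter("always")
user_code()
```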
40 changes: 23 additions & 17 deletions test/acquisition/test_active_learning.py
@@ -77,31 +77,37 @@ def test_q_neg_int_post_variance(self):
             val = qNIPV(X)
             val_exp = -variance.mean(dim=-2).squeeze(-1)
             self.assertAllClose(val, val_exp, atol=1e-4)
 
             # multi-output model
             mean = torch.zeros(4, 2, device=self.device, dtype=dtype)
             variance = torch.rand(4, 2, device=self.device, dtype=dtype)
             cov = torch.diag_embed(variance.view(-1))
             f_posterior = GPyTorchPosterior(MultitaskMultivariateNormal(mean, cov))
             mc_points = torch.rand(10, 1, device=self.device, dtype=dtype)
             mfm = MockModel(f_posterior)
-            with mock.patch.object(MockModel, "fantasize", return_value=mfm):
-                with mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs:
-                    mock_num_outputs.return_value = 2
-                    mm = MockModel(None)
-                    weights = torch.tensor([0.5, 0.5], device=self.device, dtype=dtype)
-                    qNIPV = qNegIntegratedPosteriorVariance(
-                        model=mm,
-                        mc_points=mc_points,
-                        posterior_transform=ScalarizedPosteriorTransform(
-                            weights=weights
-                        ),
-                    )
-                    X = torch.empty(1, 1, device=self.device, dtype=dtype)  # dummy
-                    val = qNIPV(X)
-                    self.assertTrue(
-                        torch.allclose(val, -0.5 * variance.mean(), atol=1e-4)
-                    )
+            with mock.patch.object(
+                MockModel, "fantasize", return_value=mfm
+            ), mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs:
+                mock_num_outputs.return_value = 2
+                mm = MockModel(None)
+
+                weights = torch.tensor([0.5, 0.5], device=self.device, dtype=dtype)
+                qNIPV = qNegIntegratedPosteriorVariance(
+                    model=mm,
+                    mc_points=mc_points,
+                    posterior_transform=ScalarizedPosteriorTransform(weights=weights),
+                )
+                X = torch.empty(1, 1, device=self.device, dtype=dtype)  # dummy
+                val = qNIPV(X)
+                self.assertAllClose(val, -0.5 * variance.mean(), atol=1e-4)
+                # without posterior_transform
+                qNIPV = qNegIntegratedPosteriorVariance(
+                    model=mm,
+                    mc_points=mc_points,
+                )
+                val = qNIPV(X)
+                self.assertAllClose(val, -variance.mean(0), atol=1e-4)
+
             # batched multi-output model
             mean = torch.zeros(4, 3, 1, 2, device=self.device, dtype=dtype)
             variance = torch.rand(4, 3, 1, 2, device=self.device, dtype=dtype)
