-
Notifications
You must be signed in to change notification settings - Fork 404
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Decoupled Acquisition Function (#1948)
Summary: Pull Request resolved: #1948 Introduce an abstract class for decoupled acquisition functions. A decoupled acquisition function where one may intend to evaluate a design on only a subset of the outcomes. Typically this would be handled by fantasizing, where one would fantasize as to what the partial observation would be if one were to evaluate a design on the subset of outcomes (e.g. you only fantasize at those outcomes) Reviewed By: esantorella Differential Revision: D47710904 fbshipit-source-id: e61b3555c5fd93b53990ce3af299650bbb5341e1
- Loading branch information
1 parent
cca54db
commit abe786a
Showing
4 changed files
with
305 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
r"""Abstract base module for decoupled acquisition functions.""" | ||
|
||
from __future__ import annotations | ||
|
||
import warnings | ||
from abc import ABC | ||
from typing import Optional | ||
|
||
import torch | ||
from botorch.acquisition.acquisition import AcquisitionFunction | ||
from botorch.exceptions import BotorchWarning | ||
from botorch.exceptions.errors import BotorchTensorDimensionError | ||
from botorch.logging import shape_to_str | ||
|
||
from botorch.models.model import ModelList | ||
from torch import Tensor | ||
|
||
|
||
class DecoupledAcquisitionFunction(AcquisitionFunction, ABC): | ||
""" | ||
Abstract base class for decoupled acquisition functions. | ||
A decoupled acquisition function where one may intend to | ||
evaluate a design on only a subset of the outcomes. | ||
Typically this would be handled by fantasizing, where one | ||
would fantasize as to what the partial observation would | ||
be if one were to evaluate a design on the subset of | ||
outcomes (e.g. you only fantasize at those outcomes). The | ||
`X_evaluation_mask` specifies which outcomes should be | ||
evaluated for each design. `X_evaluation_mask` is `q x m`, | ||
where there are q design points in the batch and m outcomes. | ||
In the asynchronous case, where there are n' pending points, | ||
we need to track which outcomes each pending point should be | ||
evaluated on. In this case, we concatenate | ||
`X_pending_evaluation_mask` with `X_evaluation_mask` to obtain | ||
the full evaluation_mask. | ||
This abstract class handles generating and updating an evaluation mask, | ||
which is a boolean tensor indicating which outcomes a given design is | ||
being evaluated on. The evaluation mask has shape `(n' + q) x m`, where | ||
n' is the number of pending points and the q represents the new | ||
candidates to be generated. | ||
If `X(_pending)_evaluation_mas`k is None, it is assumed that `X(_pending)` | ||
will be evaluated on all outcomes. | ||
""" | ||
|
||
def __init__( | ||
self, model: ModelList, X_evaluation_mask: Optional[Tensor] = None, **kwargs | ||
) -> None: | ||
r"""Initialize. | ||
Args: | ||
model: A model | ||
X_evaluation_mask: A `q x m`-dim boolean tensor | ||
indicating which outcomes the decoupled acquisition | ||
function should generate new candidates for. | ||
""" | ||
if not isinstance(model, ModelList): | ||
raise ValueError(f"{self.__class__.__name__} requires using a ModelList.") | ||
super().__init__(model=model, **kwargs) | ||
self.num_outputs = model.num_outputs | ||
self.X_evaluation_mask = X_evaluation_mask | ||
self.X_pending_evaluation_mask = None | ||
self.X_pending = None | ||
|
||
@property | ||
def X_evaluation_mask(self) -> Optional[Tensor]: | ||
r"""Get the evaluation indices for the new candidate.""" | ||
return self._X_evaluation_mask | ||
|
||
@X_evaluation_mask.setter | ||
def X_evaluation_mask(self, X_evaluation_mask: Optional[Tensor] = None) -> None: | ||
r"""Set the evaluation indices for the new candidate.""" | ||
if X_evaluation_mask is not None: | ||
# TODO: Add batch support | ||
if ( | ||
X_evaluation_mask.ndim != 2 | ||
or X_evaluation_mask.shape[-1] != self.num_outputs | ||
): | ||
raise BotorchTensorDimensionError( | ||
"Expected X_evaluation_mask to be `q x m`, but got shape" | ||
f" {shape_to_str(X_evaluation_mask.shape)}." | ||
) | ||
self._X_evaluation_mask = X_evaluation_mask | ||
|
||
def set_X_pending( | ||
self, | ||
X_pending: Optional[Tensor] = None, | ||
X_pending_evaluation_mask: Optional[Tensor] = None, | ||
) -> None: | ||
r"""Informs the AF about pending design points for different outcomes. | ||
Args: | ||
X_pending: A `n' x d` Tensor with `n'` `d`-dim design points that have | ||
been submitted for evaluation but have not yet been evaluated. | ||
X_pending_evaluation_mask: A `n' x m`-dim tensor of booleans indicating | ||
for which outputs the pending point is being evaluated on. If | ||
`X_pending_evaluation_mask` is `None`, it is assumed that | ||
`X_pending` will be evaluated on all outcomes. | ||
""" | ||
if X_pending is not None: | ||
if X_pending.requires_grad: | ||
warnings.warn( | ||
"Pending points require a gradient but the acquisition function" | ||
" will not provide a gradient to these points.", | ||
BotorchWarning, | ||
) | ||
self.X_pending = X_pending.detach().clone() | ||
if X_pending_evaluation_mask is not None: | ||
if ( | ||
X_pending_evaluation_mask.ndim != 2 | ||
or X_pending_evaluation_mask.shape[0] != X_pending.shape[0] | ||
or X_pending_evaluation_mask.shape[1] != self.num_outputs | ||
): | ||
raise BotorchTensorDimensionError( | ||
f"Expected `X_pending_evaluation_mask` of shape " | ||
f"`{X_pending.shape[0]} x {self.num_outputs}`, but " | ||
f"got {shape_to_str(X_pending_evaluation_mask.shape)}." | ||
) | ||
self.X_pending_evaluation_mask = X_pending_evaluation_mask | ||
elif self.X_evaluation_mask is not None: | ||
raise ValueError( | ||
"If `self.X_evaluation_mask` is not None, then " | ||
"`X_pending_evaluation_mask` must be provided." | ||
) | ||
|
||
else: | ||
self.X_pending = X_pending | ||
self.X_pending_evaluation_mask = X_pending_evaluation_mask | ||
|
||
def construct_evaluation_mask(self, X: Tensor) -> Optional[Tensor]: | ||
r"""Construct the boolean evaluation mask for X and X_pending | ||
Args: | ||
X: A `batch_shape x n x d`-dim tensor of designs. | ||
Returns: | ||
A `n + n' x m`-dim tensor of booleans indicating | ||
which outputs should be evaluated. | ||
""" | ||
if self.X_pending_evaluation_mask is not None: | ||
X_evaluation_mask = self.X_evaluation_mask | ||
if X_evaluation_mask is None: | ||
# evaluate all objectives for X | ||
X_evaluation_mask = torch.ones( | ||
X.shape[-2], self.num_outputs, dtype=torch.bool, device=X.device | ||
) | ||
elif X_evaluation_mask.shape[0] != X.shape[-2]: | ||
raise BotorchTensorDimensionError( | ||
"Expected the -2 dimension of X and X_evaluation_mask to match." | ||
) | ||
# construct mask for X | ||
return torch.cat( | ||
[X_evaluation_mask, self.X_pending_evaluation_mask], dim=-2 | ||
) | ||
return self.X_evaluation_mask |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import warnings | ||
|
||
import torch | ||
from botorch import settings | ||
from botorch.acquisition.decoupled import DecoupledAcquisitionFunction | ||
from botorch.exceptions import BotorchTensorDimensionError, BotorchWarning | ||
from botorch.logging import shape_to_str | ||
from botorch.models import ModelListGP, SingleTaskGP | ||
from botorch.utils.testing import BotorchTestCase | ||
|
||
|
||
class DummyDecoupledAcquisitionFunction(DecoupledAcquisitionFunction): | ||
def forward(self, X): | ||
pass | ||
|
||
|
||
class TestDecoupledAcquisitionFunction(BotorchTestCase): | ||
def test_decoupled_acquisition_function(self): | ||
msg = ( | ||
"Can't instantiate abstract class DecoupledAcquisitionFunction" | ||
" with abstract method forward" | ||
) | ||
with self.assertRaisesRegex(TypeError, msg): | ||
DecoupledAcquisitionFunction() | ||
# test raises error if model is not ModelList | ||
msg = "DummyDecoupledAcquisitionFunction requires using a ModelList." | ||
model = SingleTaskGP( | ||
torch.rand(1, 3, device=self.device), torch.rand(1, 2, device=self.device) | ||
) | ||
with self.assertRaisesRegex(ValueError, msg): | ||
DummyDecoupledAcquisitionFunction(model=model) | ||
m = SingleTaskGP( | ||
torch.rand(1, 3, device=self.device), torch.rand(1, 1, device=self.device) | ||
) | ||
model = ModelListGP(m, m) | ||
# basic test | ||
af = DummyDecoupledAcquisitionFunction(model=model) | ||
self.assertIs(af.model, model) | ||
self.assertIsNone(af.X_evaluation_mask) | ||
self.assertIsNone(af.X_pending) | ||
# test set X_evaluation_mask | ||
# test wrong number of outputs | ||
eval_mask = torch.randint(0, 2, (2, 3), device=self.device).bool() | ||
msg = ( | ||
"Expected X_evaluation_mask to be `q x m`, but got shape" | ||
f" {shape_to_str(eval_mask.shape)}." | ||
) | ||
with self.assertRaisesRegex(BotorchTensorDimensionError, msg): | ||
af.X_evaluation_mask = eval_mask | ||
# test more than 2 dimensions | ||
eval_mask.unsqueeze_(0) | ||
msg = ( | ||
"Expected X_evaluation_mask to be `q x m`, but got shape" | ||
f" {shape_to_str(eval_mask.shape)}." | ||
) | ||
with self.assertRaisesRegex(BotorchTensorDimensionError, msg): | ||
af.X_evaluation_mask = eval_mask | ||
|
||
# set eval_mask | ||
eval_mask = eval_mask[0, :, :2] | ||
af.X_evaluation_mask = eval_mask | ||
self.assertIs(af.X_evaluation_mask, eval_mask) | ||
|
||
# test set_X_pending | ||
X_pending = torch.rand(1, 1, device=self.device) | ||
msg = ( | ||
"If `self.X_evaluation_mask` is not None, then " | ||
"`X_pending_evaluation_mask` must be provided." | ||
) | ||
with self.assertRaisesRegex(ValueError, msg): | ||
af.set_X_pending(X_pending=X_pending) | ||
af.X_evaluation_mask = None | ||
X_pending = X_pending.requires_grad_(True) | ||
with warnings.catch_warnings(record=True) as ws, settings.debug(True): | ||
af.set_X_pending(X_pending) | ||
self.assertEqual(af.X_pending, X_pending) | ||
self.assertEqual(sum(issubclass(w.category, BotorchWarning) for w in ws), 1) | ||
self.assertIsNone(af.X_evaluation_mask) | ||
|
||
# test setting X_pending with X_pending_evaluation_mask | ||
X_pending = torch.rand(3, 1, device=self.device) | ||
# test raises exception | ||
# wrong number of outputs, wrong number of dims, wrong number of rows | ||
for shape in ([3, 1], [1, 3, 2], [1, 2]): | ||
eval_mask = torch.randint(0, 2, shape, device=self.device).bool() | ||
msg = ( | ||
f"Expected `X_pending_evaluation_mask` of shape `{X_pending.shape[0]} " | ||
f"x {model.num_outputs}`, but got " | ||
f"{shape_to_str(eval_mask.shape)}." | ||
) | ||
|
||
with self.assertRaisesRegex(BotorchTensorDimensionError, msg): | ||
af.set_X_pending( | ||
X_pending=X_pending, X_pending_evaluation_mask=eval_mask | ||
) | ||
eval_mask = torch.randint(0, 2, (3, 2), device=self.device).bool() | ||
af.set_X_pending(X_pending=X_pending, X_pending_evaluation_mask=eval_mask) | ||
self.assertTrue(torch.equal(af.X_pending, X_pending)) | ||
self.assertIs(af.X_pending_evaluation_mask, eval_mask) | ||
|
||
# test construct_evaluation_mask | ||
# X_evaluation_mask is None | ||
X = torch.rand(4, 5, 2, device=self.device) | ||
X_eval_mask = af.construct_evaluation_mask(X=X) | ||
expected_eval_mask = torch.cat( | ||
[torch.ones(X.shape[1:], dtype=torch.bool, device=self.device), eval_mask], | ||
dim=0, | ||
) | ||
self.assertTrue(torch.equal(X_eval_mask, expected_eval_mask)) | ||
# test X_evaluation_mask is not None | ||
# test wrong shape | ||
af.X_evaluation_mask = torch.zeros(1, 2, dtype=bool, device=self.device) | ||
msg = "Expected the -2 dimension of X and X_evaluation_mask to match." | ||
with self.assertRaisesRegex(BotorchTensorDimensionError, msg): | ||
af.construct_evaluation_mask(X=X) | ||
af.X_evaluation_mask = torch.randint(0, 2, (5, 2), device=self.device).bool() | ||
X_eval_mask = af.construct_evaluation_mask(X=X) | ||
expected_eval_mask = torch.cat([af.X_evaluation_mask, eval_mask], dim=0) | ||
self.assertTrue(torch.equal(X_eval_mask, expected_eval_mask)) | ||
|
||
# test setting X_pending as None | ||
af.set_X_pending(X_pending=None, X_pending_evaluation_mask=None) | ||
self.assertIsNone(af.X_pending) | ||
self.assertIsNone(af.X_pending_evaluation_mask) | ||
|
||
# test construct_evaluation_mask when X_pending is None | ||
self.assertTrue( | ||
torch.equal(af.construct_evaluation_mask(X=X), af.X_evaluation_mask) | ||
) |