Skip to content

Commit

Permalink
Decoupled Acquisition Function (#1948)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1948

Introduce an abstract class for decoupled acquisition functions.

A decoupled acquisition function where one may intend to evaluate a design on only a subset of the outcomes. Typically this would be handled by fantasizing, where one would fantasize as to what the partial observation would be if one were to evaluate a design on the subset of outcomes (e.g. you only fantasize at those outcomes)

Reviewed By: esantorella

Differential Revision: D47710904

fbshipit-source-id: e61b3555c5fd93b53990ce3af299650bbb5341e1
  • Loading branch information
sdaulton authored and facebook-github-bot committed Jul 27, 2023
1 parent cca54db commit abe786a
Show file tree
Hide file tree
Showing 4 changed files with 305 additions and 0 deletions.
2 changes: 2 additions & 0 deletions botorch/acquisition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
GenericCostAwareUtility,
InverseCostWeightedUtility,
)
from botorch.acquisition.decoupled import DecoupledAcquisitionFunction
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.knowledge_gradient import (
Expand Down Expand Up @@ -78,6 +79,7 @@
"AnalyticAcquisitionFunction",
"AnalyticExpectedUtilityOfBestOption",
"ConstrainedExpectedImprovement",
"DecoupledAcquisitionFunction",
"ExpectedImprovement",
"LogExpectedImprovement",
"LogNoisyExpectedImprovement",
Expand Down
163 changes: 163 additions & 0 deletions botorch/acquisition/decoupled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Abstract base module for decoupled acquisition functions."""

from __future__ import annotations

import warnings
from abc import ABC
from typing import Optional

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.exceptions import BotorchWarning
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.logging import shape_to_str

from botorch.models.model import ModelList
from torch import Tensor


class DecoupledAcquisitionFunction(AcquisitionFunction, ABC):
"""
Abstract base class for decoupled acquisition functions.
A decoupled acquisition function where one may intend to
evaluate a design on only a subset of the outcomes.
Typically this would be handled by fantasizing, where one
would fantasize as to what the partial observation would
be if one were to evaluate a design on the subset of
outcomes (e.g. you only fantasize at those outcomes). The
`X_evaluation_mask` specifies which outcomes should be
evaluated for each design. `X_evaluation_mask` is `q x m`,
where there are q design points in the batch and m outcomes.
In the asynchronous case, where there are n' pending points,
we need to track which outcomes each pending point should be
evaluated on. In this case, we concatenate
`X_pending_evaluation_mask` with `X_evaluation_mask` to obtain
the full evaluation_mask.
This abstract class handles generating and updating an evaluation mask,
which is a boolean tensor indicating which outcomes a given design is
being evaluated on. The evaluation mask has shape `(n' + q) x m`, where
n' is the number of pending points and the q represents the new
candidates to be generated.
If `X(_pending)_evaluation_mas`k is None, it is assumed that `X(_pending)`
will be evaluated on all outcomes.
"""

def __init__(
self, model: ModelList, X_evaluation_mask: Optional[Tensor] = None, **kwargs
) -> None:
r"""Initialize.
Args:
model: A model
X_evaluation_mask: A `q x m`-dim boolean tensor
indicating which outcomes the decoupled acquisition
function should generate new candidates for.
"""
if not isinstance(model, ModelList):
raise ValueError(f"{self.__class__.__name__} requires using a ModelList.")
super().__init__(model=model, **kwargs)
self.num_outputs = model.num_outputs
self.X_evaluation_mask = X_evaluation_mask
self.X_pending_evaluation_mask = None
self.X_pending = None

@property
def X_evaluation_mask(self) -> Optional[Tensor]:
r"""Get the evaluation indices for the new candidate."""
return self._X_evaluation_mask

@X_evaluation_mask.setter
def X_evaluation_mask(self, X_evaluation_mask: Optional[Tensor] = None) -> None:
r"""Set the evaluation indices for the new candidate."""
if X_evaluation_mask is not None:
# TODO: Add batch support
if (
X_evaluation_mask.ndim != 2
or X_evaluation_mask.shape[-1] != self.num_outputs
):
raise BotorchTensorDimensionError(
"Expected X_evaluation_mask to be `q x m`, but got shape"
f" {shape_to_str(X_evaluation_mask.shape)}."
)
self._X_evaluation_mask = X_evaluation_mask

def set_X_pending(
self,
X_pending: Optional[Tensor] = None,
X_pending_evaluation_mask: Optional[Tensor] = None,
) -> None:
r"""Informs the AF about pending design points for different outcomes.
Args:
X_pending: A `n' x d` Tensor with `n'` `d`-dim design points that have
been submitted for evaluation but have not yet been evaluated.
X_pending_evaluation_mask: A `n' x m`-dim tensor of booleans indicating
for which outputs the pending point is being evaluated on. If
`X_pending_evaluation_mask` is `None`, it is assumed that
`X_pending` will be evaluated on all outcomes.
"""
if X_pending is not None:
if X_pending.requires_grad:
warnings.warn(
"Pending points require a gradient but the acquisition function"
" will not provide a gradient to these points.",
BotorchWarning,
)
self.X_pending = X_pending.detach().clone()
if X_pending_evaluation_mask is not None:
if (
X_pending_evaluation_mask.ndim != 2
or X_pending_evaluation_mask.shape[0] != X_pending.shape[0]
or X_pending_evaluation_mask.shape[1] != self.num_outputs
):
raise BotorchTensorDimensionError(
f"Expected `X_pending_evaluation_mask` of shape "
f"`{X_pending.shape[0]} x {self.num_outputs}`, but "
f"got {shape_to_str(X_pending_evaluation_mask.shape)}."
)
self.X_pending_evaluation_mask = X_pending_evaluation_mask
elif self.X_evaluation_mask is not None:
raise ValueError(
"If `self.X_evaluation_mask` is not None, then "
"`X_pending_evaluation_mask` must be provided."
)

else:
self.X_pending = X_pending
self.X_pending_evaluation_mask = X_pending_evaluation_mask

def construct_evaluation_mask(self, X: Tensor) -> Optional[Tensor]:
r"""Construct the boolean evaluation mask for X and X_pending
Args:
X: A `batch_shape x n x d`-dim tensor of designs.
Returns:
A `n + n' x m`-dim tensor of booleans indicating
which outputs should be evaluated.
"""
if self.X_pending_evaluation_mask is not None:
X_evaluation_mask = self.X_evaluation_mask
if X_evaluation_mask is None:
# evaluate all objectives for X
X_evaluation_mask = torch.ones(
X.shape[-2], self.num_outputs, dtype=torch.bool, device=X.device
)
elif X_evaluation_mask.shape[0] != X.shape[-2]:
raise BotorchTensorDimensionError(
"Expected the -2 dimension of X and X_evaluation_mask to match."
)
# construct mask for X
return torch.cat(
[X_evaluation_mask, self.X_pending_evaluation_mask], dim=-2
)
return self.X_evaluation_mask
5 changes: 5 additions & 0 deletions sphinx/source/acquisition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ Cached Cholesky Acquisition Function API
.. automodule:: botorch.acquisition.cached_cholesky
:members:

Decoupled Acquisition Function API
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.decoupled
:members:

Monte-Carlo Acquisition Function API
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. currentmodule:: botorch.acquisition.monte_carlo
Expand Down
135 changes: 135 additions & 0 deletions test/acquisition/test_decoupled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import warnings

import torch
from botorch import settings
from botorch.acquisition.decoupled import DecoupledAcquisitionFunction
from botorch.exceptions import BotorchTensorDimensionError, BotorchWarning
from botorch.logging import shape_to_str
from botorch.models import ModelListGP, SingleTaskGP
from botorch.utils.testing import BotorchTestCase


class DummyDecoupledAcquisitionFunction(DecoupledAcquisitionFunction):
def forward(self, X):
pass


class TestDecoupledAcquisitionFunction(BotorchTestCase):
def test_decoupled_acquisition_function(self):
msg = (
"Can't instantiate abstract class DecoupledAcquisitionFunction"
" with abstract method forward"
)
with self.assertRaisesRegex(TypeError, msg):
DecoupledAcquisitionFunction()
# test raises error if model is not ModelList
msg = "DummyDecoupledAcquisitionFunction requires using a ModelList."
model = SingleTaskGP(
torch.rand(1, 3, device=self.device), torch.rand(1, 2, device=self.device)
)
with self.assertRaisesRegex(ValueError, msg):
DummyDecoupledAcquisitionFunction(model=model)
m = SingleTaskGP(
torch.rand(1, 3, device=self.device), torch.rand(1, 1, device=self.device)
)
model = ModelListGP(m, m)
# basic test
af = DummyDecoupledAcquisitionFunction(model=model)
self.assertIs(af.model, model)
self.assertIsNone(af.X_evaluation_mask)
self.assertIsNone(af.X_pending)
# test set X_evaluation_mask
# test wrong number of outputs
eval_mask = torch.randint(0, 2, (2, 3), device=self.device).bool()
msg = (
"Expected X_evaluation_mask to be `q x m`, but got shape"
f" {shape_to_str(eval_mask.shape)}."
)
with self.assertRaisesRegex(BotorchTensorDimensionError, msg):
af.X_evaluation_mask = eval_mask
# test more than 2 dimensions
eval_mask.unsqueeze_(0)
msg = (
"Expected X_evaluation_mask to be `q x m`, but got shape"
f" {shape_to_str(eval_mask.shape)}."
)
with self.assertRaisesRegex(BotorchTensorDimensionError, msg):
af.X_evaluation_mask = eval_mask

# set eval_mask
eval_mask = eval_mask[0, :, :2]
af.X_evaluation_mask = eval_mask
self.assertIs(af.X_evaluation_mask, eval_mask)

# test set_X_pending
X_pending = torch.rand(1, 1, device=self.device)
msg = (
"If `self.X_evaluation_mask` is not None, then "
"`X_pending_evaluation_mask` must be provided."
)
with self.assertRaisesRegex(ValueError, msg):
af.set_X_pending(X_pending=X_pending)
af.X_evaluation_mask = None
X_pending = X_pending.requires_grad_(True)
with warnings.catch_warnings(record=True) as ws, settings.debug(True):
af.set_X_pending(X_pending)
self.assertEqual(af.X_pending, X_pending)
self.assertEqual(sum(issubclass(w.category, BotorchWarning) for w in ws), 1)
self.assertIsNone(af.X_evaluation_mask)

# test setting X_pending with X_pending_evaluation_mask
X_pending = torch.rand(3, 1, device=self.device)
# test raises exception
# wrong number of outputs, wrong number of dims, wrong number of rows
for shape in ([3, 1], [1, 3, 2], [1, 2]):
eval_mask = torch.randint(0, 2, shape, device=self.device).bool()
msg = (
f"Expected `X_pending_evaluation_mask` of shape `{X_pending.shape[0]} "
f"x {model.num_outputs}`, but got "
f"{shape_to_str(eval_mask.shape)}."
)

with self.assertRaisesRegex(BotorchTensorDimensionError, msg):
af.set_X_pending(
X_pending=X_pending, X_pending_evaluation_mask=eval_mask
)
eval_mask = torch.randint(0, 2, (3, 2), device=self.device).bool()
af.set_X_pending(X_pending=X_pending, X_pending_evaluation_mask=eval_mask)
self.assertTrue(torch.equal(af.X_pending, X_pending))
self.assertIs(af.X_pending_evaluation_mask, eval_mask)

# test construct_evaluation_mask
# X_evaluation_mask is None
X = torch.rand(4, 5, 2, device=self.device)
X_eval_mask = af.construct_evaluation_mask(X=X)
expected_eval_mask = torch.cat(
[torch.ones(X.shape[1:], dtype=torch.bool, device=self.device), eval_mask],
dim=0,
)
self.assertTrue(torch.equal(X_eval_mask, expected_eval_mask))
# test X_evaluation_mask is not None
# test wrong shape
af.X_evaluation_mask = torch.zeros(1, 2, dtype=bool, device=self.device)
msg = "Expected the -2 dimension of X and X_evaluation_mask to match."
with self.assertRaisesRegex(BotorchTensorDimensionError, msg):
af.construct_evaluation_mask(X=X)
af.X_evaluation_mask = torch.randint(0, 2, (5, 2), device=self.device).bool()
X_eval_mask = af.construct_evaluation_mask(X=X)
expected_eval_mask = torch.cat([af.X_evaluation_mask, eval_mask], dim=0)
self.assertTrue(torch.equal(X_eval_mask, expected_eval_mask))

# test setting X_pending as None
af.set_X_pending(X_pending=None, X_pending_evaluation_mask=None)
self.assertIsNone(af.X_pending)
self.assertIsNone(af.X_pending_evaluation_mask)

# test construct_evaluation_mask when X_pending is None
self.assertTrue(
torch.equal(af.construct_evaluation_mask(X=X), af.X_evaluation_mask)
)

0 comments on commit abe786a

Please sign in to comment.