Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add isinstance_af #1664

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 7 additions & 19 deletions botorch/acquisition/fixed_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from torch import Tensor
from torch.nn import Module


class FixedFeatureAcquisitionFunction(AcquisitionFunction):
class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
"""A wrapper around AquisitionFunctions to fix a subset of features.

Example:
Expand Down Expand Up @@ -56,8 +56,7 @@ def __init__(
combination of `Tensor`s and numbers which can be broadcasted
to form a tensor with trailing dimension size of `d_f`.
"""
Module.__init__(self)
self.acq_func = acq_function
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
dtype = torch.float
device = torch.device("cpu")
self.d = d
Expand Down Expand Up @@ -126,24 +125,13 @@ def forward(self, X: Tensor):
X_full = self._construct_X_full(X)
return self.acq_func(X_full)

@property
def X_pending(self):
r"""Return the `X_pending` of the base acquisition function."""
try:
return self.acq_func.X_pending
except (ValueError, AttributeError):
raise ValueError(
f"Base acquisition function {type(self.acq_func).__name__} "
"does not have an `X_pending` attribute."
)

@X_pending.setter
def X_pending(self, X_pending: Optional[Tensor]):
def set_X_pending(self, X_pending: Optional[Tensor]):
r"""Sets the `X_pending` of the base acquisition function."""
if X_pending is not None:
self.acq_func.X_pending = self._construct_X_full(X_pending)
full_X_pending = self._construct_X_full(X_pending)
else:
self.acq_func.X_pending = X_pending
full_X_pending = None
self.acq_func.set_X_pending(full_X_pending)

def _construct_X_full(self, X: Tensor) -> Tensor:
r"""Constructs the full input for the base acquisition function.
Expand Down
24 changes: 5 additions & 19 deletions botorch/acquisition/penalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.acquisition.objective import GenericMCObjective
from botorch.exceptions import UnsupportedError
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from torch import Tensor


Expand Down Expand Up @@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor:
return regularization_term


class PenalizedAcquisitionFunction(AcquisitionFunction):
class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
r"""Single-outcome acquisition function regularized by the given penalty.

The usage is similar to:
Expand All @@ -161,29 +160,16 @@ def __init__(
penalty_func: The regularization function.
regularization_parameter: Regularization parameter used in optimization.
"""
super().__init__(model=raw_acqf.model)
self.raw_acqf = raw_acqf
AcquisitionFunction.__init__(self, model=raw_acqf.model)
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf)
self.penalty_func = penalty_func
self.regularization_parameter = regularization_parameter

def forward(self, X: Tensor) -> Tensor:
raw_value = self.raw_acqf(X=X)
raw_value = self.acq_func(X=X)
penalty_term = self.penalty_func(X)
return raw_value - self.regularization_parameter * penalty_term

@property
def X_pending(self) -> Optional[Tensor]:
return self.raw_acqf.X_pending

def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
self.raw_acqf.set_X_pending(X_pending=X_pending)
else:
raise UnsupportedError(
"The raw acquisition function is Analytic and does not account "
"for X_pending yet."
)


def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
r"""Computes the group lasso regularization function for the given point.
Expand Down
15 changes: 10 additions & 5 deletions botorch/acquisition/proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import torch
from botorch.acquisition import AcquisitionFunction

from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from botorch.exceptions.errors import UnsupportedError
from botorch.models import ModelListGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
Expand All @@ -25,7 +27,7 @@
from torch.nn import Module


class ProximalAcquisitionFunction(AcquisitionFunction):
class ProximalAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
"""A wrapper around AcquisitionFunctions to add proximal weighting of the
acquisition function. The acquisition function is
weighted via a squared exponential centered at the last training point,
Expand Down Expand Up @@ -70,17 +72,14 @@ def __init__(
beta: If not None, apply a softplus transform to the base acquisition
function, allows negative base acquisition function values.
"""
Module.__init__(self)

self.acq_func = acq_function
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
model = self.acq_func.model

if hasattr(acq_function, "X_pending"):
if acq_function.X_pending is not None:
raise UnsupportedError(
"Proximal acquisition function requires `X_pending` to be None."
)
self.X_pending = acq_function.X_pending

self.register_buffer("proximal_weights", proximal_weights)
self.register_buffer(
Expand All @@ -91,6 +90,12 @@ def __init__(

_validate_model(model, proximal_weights)

def set_X_pending(self, X_pending: Optional[Tensor]) -> None:
r"""Sets the `X_pending` of the base acquisition function."""
raise UnsupportedError(
"Proximal acquisition function does not support `X_pending`."
)

@t_batch_mode_transform(expected_q=1, assert_output_shape=False)
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate base acquisition function with proximal weighting.
Expand Down
17 changes: 15 additions & 2 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from __future__ import annotations

import math
from typing import Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from botorch.acquisition import analytic, monte_carlo, multi_objective # noqa F401
Expand All @@ -22,6 +22,7 @@
MCAcquisitionObjective,
PosteriorTransform,
)
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from botorch.exceptions.errors import UnsupportedError
from botorch.models.fully_bayesian import MCMC_DIM
from botorch.models.model import Model
Expand Down Expand Up @@ -253,6 +254,18 @@ def objective(Y: Tensor, X: Optional[Tensor] = None):
return -(lb.clamp_max(0.0))


def isinstance_af(
__obj: object,
__class_or_tuple: Union[type, tuple[Union[type, tuple[Any, ...]], ...]],
) -> bool:
r"""A variant of isinstance first checks for the acq_func attribute on wrapped acquisition functions."""
if isinstance(__obj, AbstractAcquisitionFunctionWrapper):
isinstance_base_af = isinstance(__obj.acq_func, __class_or_tuple)
else:
isinstance_base_af = False
return isinstance_base_af or isinstance(__obj, __class_or_tuple)


def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
r"""Determine whether a given acquisition function is non-negative.

Expand All @@ -267,7 +280,7 @@ def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
>>> qEI = qExpectedImprovement(model, best_f=0.1)
>>> is_nonnegative(qEI) # returns True
"""
return isinstance(
return isinstance_af(
acq_function,
(
analytic.ExpectedImprovement,
Expand Down
55 changes: 55 additions & 0 deletions botorch/acquisition/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
A wrapper classes around AcquisitionFunctions to modify inputs and outputs.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional

from botorch.acquisition.acquisition import AcquisitionFunction
from torch import Tensor
from torch.nn import Module


class AbstractAcquisitionFunctionWrapper(AcquisitionFunction, ABC):
r"""Abstract acquisition wrapper."""

def __init__(self, acq_function: AcquisitionFunction) -> None:
Module.__init__(self)
self.acq_func = acq_function

@property
def X_pending(self) -> Optional[Tensor]:
r"""Return the `X_pending` of the base acquisition function."""
try:
return self.acq_func.X_pending
except (ValueError, AttributeError):
raise ValueError(
f"Base acquisition function {type(self.acq_func).__name__} "
"does not have an `X_pending` attribute."
)

def set_X_pending(self, X_pending: Optional[Tensor]) -> None:
r"""Sets the `X_pending` of the base acquisition function."""
self.acq_func.set_X_pending(X_pending)

@abstractmethod
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate the wrapped acquisition function on the candidate set X.

Args:
X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim
design points each.

Returns:
A `(b)`-dim Tensor of acquisition function values at the given
design points `X`.
"""
pass # pragma: no cover
9 changes: 7 additions & 2 deletions sphinx/source/acquisition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ Analytic Acquisition Function API
.. autoclass:: AnalyticAcquisitionFunction
:members:

Acquisition Function Wrapper API
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.wrapper
:members:

Cached Cholesky Acquisition Function API
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.cached_cholesky
Expand Down Expand Up @@ -65,7 +70,7 @@ Multi-Objective Analytic Acquisition Functions
.. automodule:: botorch.acquisition.multi_objective.analytic
:members:
:exclude-members: MultiObjectiveAnalyticAcquisitionFunction

Multi-Objective Joint Entropy Search Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.joint_entropy_search
Expand All @@ -86,7 +91,7 @@ Multi-Objective Multi-Fidelity Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.multi_fidelity
:members:

Multi-Objective Predictive Entropy Search Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search
Expand Down
2 changes: 1 addition & 1 deletion test/acquisition/test_fixed_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_fixed_features(self):
qEI_ff.set_X_pending(X_pending[..., :-1])
self.assertAllClose(qEI.X_pending, X_pending)
# test setting to None
qEI_ff.X_pending = None
qEI_ff.set_X_pending(None)
self.assertIsNone(qEI_ff.X_pending)

# test gradient
Expand Down
8 changes: 7 additions & 1 deletion test/acquisition/test_proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,15 @@ def test_proximal(self):

# test for x_pending points
pending_acq = DummyAcquisitionFunction(model)
pending_acq.set_X_pending(torch.rand(3, 3, device=self.device, dtype=dtype))
X_pending = torch.rand(3, 3, device=self.device, dtype=dtype)
pending_acq.set_X_pending(X_pending)
with self.assertRaises(UnsupportedError):
ProximalAcquisitionFunction(pending_acq, proximal_weights)
# test setting pending points
pending_acq.set_X_pending(None)
af = ProximalAcquisitionFunction(pending_acq, proximal_weights)
with self.assertRaises(UnsupportedError):
af.set_X_pending(X_pending)

# test model with multi-batch training inputs
train_X = torch.rand(5, 2, 3, device=self.device, dtype=dtype)
Expand Down
61 changes: 60 additions & 1 deletion test/acquisition/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from unittest import mock

import torch
from botorch.acquisition import monte_carlo
from botorch.acquisition import analytic, monte_carlo, multi_objective
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
from botorch.acquisition.multi_objective import (
MCMultiOutputObjective,
monte_carlo as moo_monte_carlo,
Expand All @@ -18,10 +19,13 @@
MCAcquisitionObjective,
ScalarizedPosteriorTransform,
)
from botorch.acquisition.proximal import ProximalAcquisitionFunction
from botorch.acquisition.utils import (
expand_trace_observations,
get_acquisition_function,
get_infeasible_cost,
is_nonnegative,
isinstance_af,
project_to_sample_points,
project_to_target_fidelity,
prune_inferior_points,
Expand Down Expand Up @@ -606,6 +610,61 @@ def test_get_infeasible_cost(self):
self.assertAllClose(M4, torch.tensor([1.0], **tkwargs))


class TestIsNonnegative(BotorchTestCase):
def test_is_nonnegative(self):
nonneg_afs = (
analytic.ExpectedImprovement,
analytic.ConstrainedExpectedImprovement,
analytic.ProbabilityOfImprovement,
analytic.NoisyExpectedImprovement,
monte_carlo.qExpectedImprovement,
monte_carlo.qNoisyExpectedImprovement,
monte_carlo.qProbabilityOfImprovement,
multi_objective.analytic.ExpectedHypervolumeImprovement,
multi_objective.monte_carlo.qExpectedHypervolumeImprovement,
multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement,
)
mm = MockModel(
MockPosterior(
mean=torch.rand(1, 1, device=self.device),
variance=torch.ones(1, 1, device=self.device),
)
)
acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0)
with mock.patch(
"botorch.acquisition.utils.isinstance_af", return_value=True
) as mock_isinstance_af:
self.assertTrue(is_nonnegative(acq_function=acq_func))
mock_isinstance_af.assert_called_once()
cargs, _ = mock_isinstance_af.call_args
self.assertIs(cargs[0], acq_func)
self.assertEqual(cargs[1], nonneg_afs)
acq_func = analytic.UpperConfidenceBound(model=mm, beta=2.0)
self.assertFalse(is_nonnegative(acq_function=acq_func))


class TestIsinstanceAf(BotorchTestCase):
def test_isinstance_af(self):
mm = MockModel(
MockPosterior(
mean=torch.rand(1, 1, device=self.device),
variance=torch.ones(1, 1, device=self.device),
)
)
acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0)
self.assertTrue(isinstance_af(acq_func, analytic.ExpectedImprovement))
self.assertFalse(isinstance_af(acq_func, analytic.UpperConfidenceBound))
wrapped_af = FixedFeatureAcquisitionFunction(
acq_function=acq_func, d=2, columns=[1], values=[0.0]
)
# test base af class
self.assertTrue(isinstance_af(wrapped_af, analytic.ExpectedImprovement))
self.assertFalse(isinstance_af(wrapped_af, analytic.UpperConfidenceBound))
# test wrapper class
self.assertTrue(isinstance_af(wrapped_af, FixedFeatureAcquisitionFunction))
self.assertFalse(isinstance_af(wrapped_af, ProximalAcquisitionFunction))


class TestPruneInferiorPoints(BotorchTestCase):
def test_prune_inferior_points(self):
for dtype in (torch.float, torch.double):
Expand Down
Loading