NoisyExpectedImprovementMixin
Summary: This commit introduces a mixin design to increase code sharing between `qLogNEI` and `qNEI`.

Differential Revision: D47511290

fbshipit-source-id: efe6b1550948c5df7da1e185c7770ba062d60b06
SebastianAment authored and facebook-github-bot committed Jul 18, 2023
1 parent 8b8bbee · commit 8fdcaaa
Showing 4 changed files with 173 additions and 261 deletions.
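
The mixin design mentioned in the summary can be illustrated with a minimal, hypothetical sketch: the concrete acquisition class initializes each base explicitly instead of relying on cooperative `super().__init__` calls, and explicitly delegates methods that both bases define. The class and method names below are simplified stand-ins for the BoTorch classes, not the library's actual code; they only mirror the shape of the changes to `qLogNoisyExpectedImprovement` in the diff that follows.

```python
# Illustrative sketch of the mixin pattern; simplified stand-ins only.
from typing import Tuple


class ImprovementAcquisition:
    """Stand-in for LogImprovementMCAcquisitionFunction."""

    def __init__(self, model) -> None:
        self.model = model


class NoisyBaselineMixin:
    """Stand-in for NoisyExpectedImprovementMixin: shared baseline handling."""

    def __init__(self, model, X_baseline) -> None:
        # Baseline points kept around for best_f computation; the real mixin
        # also sets up pruning and root-decomposition caching.
        self.X_baseline = X_baseline

    def _get_samples_and_objectives(self, X) -> Tuple:
        # The shared implementation would sample the joint posterior over
        # [X_baseline, X] and split the result; elided in this sketch.
        raise NotImplementedError


class NoisyImprovementAcquisition(ImprovementAcquisition, NoisyBaselineMixin):
    def __init__(self, model, X_baseline) -> None:
        # Explicit base-class initialization, mirroring the __init__ change below.
        ImprovementAcquisition.__init__(self, model=model)
        NoisyBaselineMixin.__init__(self, model=model, X_baseline=X_baseline)

    def _get_samples_and_objectives(self, X) -> Tuple:
        # Explicit delegation: both bases may define this method, so the
        # class does not rely on the method resolution order.
        return NoisyBaselineMixin._get_samples_and_objectives(self, X)
```
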
151 changes: 11 additions & 140 deletions botorch/acquisition/logei.py
@@ -9,24 +9,20 @@

from __future__ import annotations

from copy import deepcopy

from functools import partial

from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union

import torch
from botorch.acquisition.cached_cholesky import CachedCholeskyMCAcquisitionFunction
from botorch.acquisition.monte_carlo import SampleReducingMCAcquisitionFunction
from botorch.acquisition.monte_carlo import (
NoisyExpectedImprovementMixin,
SampleReducingMCAcquisitionFunction,
)
from botorch.acquisition.objective import (
ConstrainedMCObjective,
MCAcquisitionObjective,
PosteriorTransform,
)
from botorch.acquisition.utils import (
compute_best_feasible_objective,
prune_inferior_points,
)
from botorch.exceptions.errors import BotorchError
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
@@ -37,7 +33,6 @@
logmeanexp,
smooth_amax,
)
from botorch.utils.transforms import match_batch_shape
from torch import Tensor

"""
@@ -227,7 +222,7 @@ def _sample_forward(self, obj: Tensor) -> Tensor:


class qLogNoisyExpectedImprovement(
LogImprovementMCAcquisitionFunction, CachedCholeskyMCAcquisitionFunction
LogImprovementMCAcquisitionFunction, NoisyExpectedImprovementMixin
):
r"""MC-based batch Log Noisy Expected Improvement.
@@ -308,9 +303,8 @@ def __init__(
the incremental q(Log)NEI from the new point. This would greatly increase
efficiency for large batches.
"""
# TODO: separate out baseline variables initialization and other functions
# in qNEI to avoid duplication of both code and work at runtime.
super().__init__(
LogImprovementMCAcquisitionFunction.__init__(
self,
model=model,
sampler=sampler,
objective=objective,
@@ -322,7 +316,8 @@ def __init__(
tau_max=tau_max,
)
self.tau_relu = tau_relu
self._init_baseline(
NoisyExpectedImprovementMixin.__init__(
self,
model=model,
X_baseline=X_baseline,
sampler=sampler,
@@ -350,133 +345,9 @@ def _sample_forward(self, obj: Tensor) -> Tensor:
fat=self._fat,
)

def _init_baseline(
self,
model: Model,
X_baseline: Tensor,
sampler: Optional[MCSampler] = None,
objective: Optional[MCAcquisitionObjective] = None,
posterior_transform: Optional[PosteriorTransform] = None,
prune_baseline: bool = False,
cache_root: bool = True,
**kwargs: Any,
) -> None:
# setup of CachedCholeskyMCAcquisitionFunction
self._setup(model=model, cache_root=cache_root)
if prune_baseline:
X_baseline = prune_inferior_points(
model=model,
X=X_baseline,
objective=objective,
posterior_transform=posterior_transform,
marginalize_dim=kwargs.get("marginalize_dim"),
)
self.register_buffer("X_baseline", X_baseline)
# registering buffers for _get_samples_and_objectives in the next `if` block
self.register_buffer("baseline_samples", None)
self.register_buffer("baseline_obj", None)
if self._cache_root:
self.q_in = -1
# set baseline samples
with torch.no_grad(): # this is _get_samples_and_objectives(X_baseline)
posterior = self.model.posterior(
X_baseline, posterior_transform=self.posterior_transform
)
# Note: The root decomposition is cached in two different places. It
# may be confusing to have two different caches, but this is not
# trivial to change since each is needed for a different reason:
# - LinearOperator caching to `posterior.mvn` allows for reuse within
# this function, which may be helpful if the same root decomposition
# is produced by the calls to `self.base_sampler` and
# `self._cache_root_decomposition`.
# - self._baseline_L allows a root decomposition to be persisted outside
# this method.
self.baseline_samples = self.get_posterior_samples(posterior)
self.baseline_obj = self.objective(self.baseline_samples, X=X_baseline)

# We make a copy here because we will write an attribute `base_samples`
# to `self.base_sampler.base_samples`, and we don't want to mutate
# `self.sampler`.
self.base_sampler = deepcopy(self.sampler)
self.register_buffer(
"_baseline_best_f",
self._compute_best_feasible_objective(
samples=self.baseline_samples, obj=self.baseline_obj
),
)
self._baseline_L = self._compute_root_decomposition(posterior=posterior)

def compute_best_f(self, obj: Tensor) -> Tensor:
"""Computes the best (feasible) noisy objective value.
Args:
obj: `sample_shape x batch_shape x q`-dim Tensor of objectives in forward.
Returns:
A `sample_shape x batch_shape x 1`-dim Tensor of best feasible objectives.
"""
if self._cache_root:
val = self._baseline_best_f
else:
val = self._compute_best_feasible_objective(
samples=self.baseline_samples, obj=self.baseline_obj
)
# ensuring shape, dtype, device compatibility with obj
n_sample_dims = len(self.sample_shape)
view_shape = torch.Size(
[
*val.shape[:n_sample_dims], # sample dimensions
*(1,) * (obj.ndim - val.ndim), # pad to match obj
*val.shape[n_sample_dims:], # the rest
]
)
return val.view(view_shape).to(obj)

def _get_samples_and_objectives(self, X: Tensor) -> Tuple[Tensor, Tensor]:
r"""Compute samples at new points, using the cached root decomposition.
Args:
X: A `batch_shape x q x d`-dim tensor of inputs.
Returns:
A two-tuple `(samples, obj)`, where `samples` is a tensor of posterior
samples with shape `sample_shape x batch_shape x q x m`, and `obj` is a
tensor of MC objective values with shape `sample_shape x batch_shape x q`.
"""
n_baseline, q = self.X_baseline.shape[-2], X.shape[-2]
X_full = torch.cat([match_batch_shape(self.X_baseline, X), X], dim=-2)
# TODO: Implement more efficient way to compute posterior over both training and
# test points in GPyTorch (https://github.com/cornellius-gp/gpytorch/issues/567)
posterior = self.model.posterior(
X_full, posterior_transform=self.posterior_transform
)
if not self._cache_root:
samples_full = super().get_posterior_samples(posterior)
obj_full = self.objective(samples_full, X=X_full)
# assigning baseline buffers so `best_f` can be computed in _sample_forward
self.baseline_samples, samples = samples_full.split([n_baseline, q], dim=-2)
self.baseline_obj, obj = obj_full.split([n_baseline, q], dim=-1)
return samples, obj

# handle one-to-many input transforms
n_plus_q = X_full.shape[-2]
n_w = posterior._extended_shape()[-2] // n_plus_q
q_in = q * n_w
self._set_sampler(q_in=q_in, posterior=posterior)
samples = self._get_f_X_samples(posterior=posterior, q_in=q_in)
obj = self.objective(samples, X=X_full[..., -q:, :])
return samples, obj

def _compute_best_feasible_objective(self, samples: Tensor, obj: Tensor) -> Tensor:
return compute_best_feasible_objective(
samples=samples,
obj=obj,
constraints=self._constraints,
model=self.model,
objective=self.objective,
posterior_transform=self.posterior_transform,
X_baseline=self.X_baseline,
)
# Explicit, as both parent classes have this method, so no MRO magic required.
return NoisyExpectedImprovementMixin._get_samples_and_objectives(self, X)


"""
(Diff for the remaining 3 changed files not shown.)

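For context, a sketch of how the refactored class is typically constructed. This usage is based on the public BoTorch API around the time of this commit; the data and hyperparameters are made up for illustration, and constructor details may differ across versions.

```python
import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.models import SingleTaskGP

# Toy data; values are arbitrary and only serve as a smoke test.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True).sin()

# An unfitted GP suffices to exercise the acquisition function's setup.
model = SingleTaskGP(train_X, train_Y)

acqf = qLogNoisyExpectedImprovement(
    model=model,
    X_baseline=train_X,   # baseline handling now lives in NoisyExpectedImprovementMixin
    prune_baseline=True,  # drop baseline points unlikely to be the best
    cache_root=True,      # cache the root decomposition of the baseline posterior
)

# Evaluate the acquisition on a candidate q-batch of 3 points.
candidates = torch.rand(3, 2, dtype=torch.double)
value = acqf(candidates)  # 1-element tensor: log acquisition value of the joint batch
```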