SkewGP models (pytorch#1906)
Summary:
Pull Request resolved: pytorch#1906

SkewGP models

Differential Revision: D47082855
Balandat authored and facebook-github-bot committed Apr 10, 2024
1 parent e8cbbae commit 7809277
Showing 7 changed files with 1,131 additions and 0 deletions.
100 changes: 100 additions & 0 deletions botorch/acquisition/expected_feasible_improvement.py
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Optional

import torch
from botorch.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.exceptions.errors import UnsupportedError
from botorch.fit import fit_gpytorch_mll
from botorch.models.skew_gp import ExactMarginalLogLikelihood_v2, SkewGPClassifier
from botorch.utils.transforms import t_batch_mode_transform


class SkewGPClassifierMixin:
    def __init__(
        self,
        X_pending: Optional[torch.Tensor] = None,
        classifier: Optional[SkewGPClassifier] = None,
    ) -> None:
        self.set_X_pending(X_pending)
        self._classifier = classifier

        # Fit the classifier up front so that it is not fit inside the query
        # optimizer's closure.
        self.classifier  # TODO: Improve me

    def set_X_pending(self, *args, **kwargs) -> None:
        AcquisitionFunction.set_X_pending(self, *args, **kwargs)
        self._classifier = None

    @property
    def classifier(self) -> Optional[SkewGPClassifier]:
        if self._classifier is None and self.X_pending is not None:
            X_succ = self.model.train_inputs[0]
            X_fail = self.X_pending
            # Deal with multi-output SingleTaskGP models, which have an additional batch dim.
            if X_succ.ndim > X_fail.ndim:
                if not all((X_ == X_succ[0]).all() for X_ in X_succ[1:]):
                    # Without a block design, the training inputs are ambiguous - give up.
                    raise UnsupportedError("Only block design models are supported")
                X_succ = X_succ[0]
            X = torch.cat([X_succ, X_fail], dim=0)
            Y = torch.cat(
                [
                    torch.full(X_succ.shape[:-1], True),
                    torch.full(X_fail.shape[:-1], False),
                ],
                dim=0,
            )
            model = self._classifier = SkewGPClassifier(train_X=X, train_Y=Y)
            fit_gpytorch_mll(ExactMarginalLogLikelihood_v2(model.likelihood, model))
        return self._classifier


class ExpectedFeasibleImprovement(SkewGPClassifierMixin, ExpectedImprovement):
    def __init__(
        self,
        *args,
        X_pending: Optional[torch.Tensor] = None,
        classifier: Optional[SkewGPClassifier] = None,
        **kwargs,
    ):
        ExpectedImprovement.__init__(self, *args, **kwargs)
        SkewGPClassifierMixin.__init__(self, X_pending=X_pending, classifier=classifier)

    @t_batch_mode_transform(expected_q=1, assert_output_shape=False)
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        ei = super().forward(X)
        if self.classifier is None:
            return ei

        p_feas = self.classifier.posterior_predictive(X)
        return p_feas.mean.view(ei.shape) * ei


class qExpectedFeasibleImprovement(SkewGPClassifierMixin, qExpectedImprovement):
    def __init__(
        self,
        *args,
        X_pending: Optional[torch.Tensor] = None,
        classifier: Optional[SkewGPClassifier] = None,
        **kwargs,
    ):
        qExpectedImprovement.__init__(self, *args, **kwargs)
        SkewGPClassifierMixin.__init__(self, X_pending=X_pending, classifier=classifier)

    @t_batch_mode_transform(expected_q=1, assert_output_shape=False)
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        ei = super().forward(X)
        if self.classifier is None:
            return ei

        p_feas = self.classifier.posterior_predictive(X)
        return p_feas.mean.view(ei.shape) * ei
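
For orientation, here is a minimal usage sketch of the new acquisition function. It is not part of this diff; the toy tensors `train_X`, `train_Y`, and `X_failed` are hypothetical stand-ins for real data:

import torch
from botorch.acquisition.expected_feasible_improvement import (
    ExpectedFeasibleImprovement,
)
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy data: 8 successful evaluations and 3 failed candidate points.
train_X = torch.rand(8, 2, dtype=torch.float64)
train_Y = torch.randn(8, 1, dtype=torch.float64)
X_failed = torch.rand(3, 2, dtype=torch.float64)

model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

# Failed points are passed as X_pending; the mixin fits a SkewGPClassifier on
# the success/failure labels and weights EI by the predicted P(feasible).
acqf = ExpectedFeasibleImprovement(
    model=model, best_f=train_Y.max(), X_pending=X_failed
)
values = acqf(torch.rand(5, 1, 2, dtype=torch.float64))  # t-batch of 5 candidates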
9 changes: 9 additions & 0 deletions botorch/acquisition/input_constructors.py
@@ -42,6 +42,10 @@
    UpperConfidenceBound,
)
from botorch.acquisition.cost_aware import InverseCostWeightedUtility
from botorch.acquisition.expected_feasible_improvement import (
    ExpectedFeasibleImprovement,
    qExpectedFeasibleImprovement,
)
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.knowledge_gradient import (
@@ -1562,3 +1566,8 @@ def construct_inputs_qJES(
"num_samples": num_samples,
}
return inputs


@acqf_input_constructor(ExpectedFeasibleImprovement)
def _construct_inputs_efi(*args, X_pending=None, **kwargs):
return {"X_pending": X_pending, **construct_inputs_best_f(*args, **kwargs)}
3 changes: 3 additions & 0 deletions botorch/models/__init__.py
@@ -33,6 +33,7 @@
    MultiTaskGP,
)
from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood
from botorch.models.skew_gp import SkewGP, SkewGPClassifier

__all__ = [
"AffineDeterministicModel",
@@ -56,4 +57,6 @@
"SingleTaskGP",
"SingleTaskMultiFidelityGP",
"SingleTaskVariationalGP",
"SkewGP",
"SkewGPClassifier",
]
2 changes: 2 additions & 0 deletions botorch/models/likelihoods/__init__.py
@@ -4,13 +4,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from botorch.models.likelihoods.affine_probit import AffineProbitLikelihood
from botorch.models.likelihoods.pairwise import (
    PairwiseLogitLikelihood,
    PairwiseProbitLikelihood,
)


__all__ = [
"AffineProbitLikelihood",
"PairwiseProbitLikelihood",
"PairwiseLogitLikelihood",
]
134 changes: 134 additions & 0 deletions botorch/models/likelihoods/affine_probit.py
@@ -0,0 +1,134 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Optional, Tuple, Union

import torch
from botorch.utils.probability import TruncatedMultivariateNormal, UnifiedSkewNormal
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import _GaussianLikelihoodBase
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.likelihoods.noise_models import Noise
from linear_operator.operators import DiagLinearOperator, LinearOperator
from torch import BoolTensor, Tensor
from torch.nn.functional import pad


class AffineProbitLikelihood(_GaussianLikelihoodBase, Likelihood):
    def __init__(
        self,
        weight: Union[LinearOperator, Tensor],
        bias: Optional[Union[LinearOperator, Tensor]] = None,
        noise_covar: Optional[Noise] = None,
    ):
        """Affine probit likelihood `P(f + e > 0)`, where `f = Ax + b` is an affine
        transformation of an `n`-dimensional Gaussian random vector `x` and `e ~ Noise`
        is an `m`-dimensional centered Gaussian noise vector.

        Args:
            weight: Matrix `A` with shape (... x m x n).
            bias: Vector `b` with shape (... x m).
            noise_covar: Noise covariance matrix with shape (... x m x m).
        """
        Likelihood.__init__(self)
        self.weight = weight
        self.bias = bias
        self.noise_covar = noise_covar

    def get_affine_transform(
        self, diag: Optional[Tensor] = None
    ) -> Tuple[Union[Tensor, LinearOperator], Optional[Union[Tensor, LinearOperator]]]:
        """Returns the base affine transform, optionally rescaled row-wise, e.g. with
        sign flips for negative labels.

        Args:
            diag: Scaling factors `d` for the affine transform such that `(DA, Db)`
                is returned, where `D = diag(d)`.

        Returns:
            Tensor representation of the (possibly rescaled) affine transform `(A, b)`.
        """
        if diag is None:
            return self.weight, self.bias

        D = DiagLinearOperator(diag)
        return D @ self.weight, None if (self.bias is None) else D @ self.bias

    def marginal(
        self,
        function_dist: MultivariateNormal,
        observations: Optional[BoolTensor] = None,
    ) -> TruncatedMultivariateNormal:
        """Returns the truncated multivariate normal distribution of `h | h > 0`, where
        `x` is a Gaussian random vector, `h = (Ax + b) + e`, and `e ~ Noise`."""
        gauss_loc = function_dist.loc
        gauss_cov = function_dist.covariance_matrix
        signed_labels = (
            None
            if observations is None
            else 2 * observations.to(dtype=gauss_loc.dtype, device=gauss_loc.device) - 1
        )

        A, b = self.get_affine_transform(diag=signed_labels)
        trunc_loc = A @ gauss_loc if (b is None) else A @ gauss_loc + b
        trunc_cov = A @ gauss_cov @ A.transpose(-1, -2)
        if self.noise_covar is not None:
            noise_diag = self.noise_covar(shape=trunc_cov.shape[:-1])
            trunc_cov = (trunc_cov + noise_diag).to_dense()

        return TruncatedMultivariateNormal(
            loc=trunc_loc,
            covariance_matrix=trunc_cov,
            bounds=pad(torch.full_like(trunc_loc, float("inf")).unsqueeze(-1), (1, 0)),
            validate_args=False,
        )

    def log_marginal(
        self,
        observations: BoolTensor,
        function_dist: MultivariateNormal,
    ) -> Tensor:
        """Returns the log marginal likelihood `ln p(y) = ln P([2y - 1](f + e) > 0)`,
        where `f = Ax + b` and `e ~ Noise`."""
        return self.marginal(function_dist, observations=observations).log_partition

    def latent_marginal(
        self,
        function_dist: MultivariateNormal,
        observations: Optional[BoolTensor] = None,
    ) -> UnifiedSkewNormal:
        """Returns the UnifiedSkewNormal distribution of `x | f + e > 0`, where
        `x` is a Gaussian random vector, `f = Ax + b`, and `e ~ Noise`."""
        gauss_loc = function_dist.loc
        gauss_cov = function_dist.covariance_matrix
        signed_labels = (
            None
            if observations is None
            else 2 * observations.to(dtype=gauss_loc.dtype, device=gauss_loc.device) - 1
        )

        A, b = self.get_affine_transform(diag=signed_labels)
        trunc_loc = A @ gauss_loc if (b is None) else A @ gauss_loc + b
        cross_cov = A @ gauss_cov
        trunc_cov = cross_cov @ A.transpose(-1, -2)
        if self.noise_covar is not None:
            noise_diag = self.noise_covar(shape=trunc_cov.shape[:-1])
            trunc_cov = (trunc_cov + noise_diag).to_dense()

        trunc = TruncatedMultivariateNormal(
            loc=trunc_loc,
            covariance_matrix=trunc_cov,
            bounds=pad(torch.full_like(trunc_loc, float("inf")).unsqueeze(-1), (1, 0)),
            validate_args=False,
        )

        return UnifiedSkewNormal(
            trunc=trunc,
            gauss=function_dist,
            cross_covariance_matrix=cross_cov,
        )
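
To make the likelihood's semantics concrete, here is a toy sketch, not part of this diff, using an identity weight and no noise model, so the latent `x` is standard normal and the marginal reduces to componentwise sign probabilities:

import torch
from botorch.models.likelihoods.affine_probit import AffineProbitLikelihood
from gpytorch.distributions import MultivariateNormal

# Identity weight and no noise: the likelihood is P([2y - 1] * x > 0) for the
# joint Gaussian x. With independent standard normals, each component has
# probability 0.5 of matching its label.
n = 3
likelihood = AffineProbitLikelihood(weight=torch.eye(n))
function_dist = MultivariateNormal(torch.zeros(n), torch.eye(n))
labels = torch.tensor([True, False, True])

# Should be close to 3 * ln(0.5) for this independent toy case.
log_prob = likelihood.log_marginal(labels, function_dist)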