Commit d4acc11 (parent: f94a191)

Summary: SkewGP models
Differential Revision: D47082855
fbshipit-source-id: 5a7f5ac17bf14623cbf876cc50dd5e40ffe210ed

Showing 7 changed files with 1,132 additions and 1 deletion.
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Optional

import torch
from botorch.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.exceptions.errors import UnsupportedError
from botorch.fit import fit_gpytorch_mll
from botorch.models.skew_gp import ExactMarginalLogLikelihood_v2, SkewGPClassifier
from botorch.utils.transforms import t_batch_mode_transform


class SkewGPClassifierMixin:
    def __init__(
        self,
        X_pending: Optional[torch.Tensor] = None,
        classifier: Optional[SkewGPClassifier] = None,
    ) -> None:
        self.set_X_pending(X_pending)
        self._classifier = classifier

        # Fit the classifier up front, so that fitting does not happen inside
        # the query optimizer's closure.
        self.classifier  # TODO: Improve me

    def set_X_pending(self, *args, **kwargs) -> None:
        AcquisitionFunction.set_X_pending(self, *args, **kwargs)
        self._classifier = None

    @property
    def classifier(self) -> Optional[SkewGPClassifier]:
        if self._classifier is None and self.X_pending is not None:
            X_succ = self.model.train_inputs[0]
            X_fail = self.X_pending
            # deal with multi-output SingleTaskGP models (which have an additional
            # batch dim)
            if X_succ.ndim > X_fail.ndim:
                if not all((X_ == X_succ[0]).all() for X_ in X_succ[1:]):
                    # if we don't have a block design, things are ambiguous - give up
                    raise UnsupportedError("Only block design models are supported")
                X_succ = X_succ[0]
            X = torch.cat([X_succ, X_fail], dim=0)
            Y = torch.cat(
                [
                    torch.full(X_succ.shape[:-1], True),
                    torch.full(X_fail.shape[:-1], False),
                ],
                dim=0,
            )
            model = self._classifier = SkewGPClassifier(train_X=X, train_Y=Y)
            fit_gpytorch_mll(ExactMarginalLogLikelihood_v2(model.likelihood, model))
        return self._classifier


class ExpectedFeasibleImprovement(SkewGPClassifierMixin, ExpectedImprovement):
    def __init__(
        self,
        *args,
        X_pending: Optional[torch.Tensor] = None,
        classifier: Optional[SkewGPClassifier] = None,
        **kwargs,
    ):
        ExpectedImprovement.__init__(self, *args, **kwargs)
        SkewGPClassifierMixin.__init__(self, X_pending=X_pending, classifier=classifier)

    @t_batch_mode_transform(expected_q=1, assert_output_shape=False)
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        ei = super().forward(X)
        if self.classifier is None:
            return ei

        p_feas = self.classifier.posterior_predictive(X)
        return p_feas.mean.view(ei.shape) * ei


class qExpectedFeasibleImprovement(SkewGPClassifierMixin, qExpectedImprovement):
    def __init__(
        self,
        *args,
        X_pending: Optional[torch.Tensor] = None,
        classifier: Optional[SkewGPClassifier] = None,
        **kwargs,
    ):
        qExpectedImprovement.__init__(self, *args, **kwargs)
        SkewGPClassifierMixin.__init__(self, X_pending=X_pending, classifier=classifier)

    @t_batch_mode_transform(expected_q=1, assert_output_shape=False)
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        ei = super().forward(X)
        if self.classifier is None:
            return ei

        p_feas = self.classifier.posterior_predictive(X)
        return p_feas.mean.view(ei.shape) * ei
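
For context, here is a minimal usage sketch (not part of the diff). It assumes the `botorch.models.skew_gp` module introduced by this commit is importable and that `SkewGPClassifier.posterior_predictive` behaves as used in `forward` above; the toy data and shapes are invented for illustration.

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy problem: 8 successful evaluations of a 2D objective.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True).sin()
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

# Points where evaluation failed. Passed as X_pending, they become the
# negative class of the SkewGPClassifier fitted inside the mixin.
X_fail = torch.rand(3, 2, dtype=torch.double)

acqf = ExpectedFeasibleImprovement(
    model=model, best_f=train_Y.max(), X_pending=X_fail
)

# A t-batch of 5 candidates with q=1, as expected by t_batch_mode_transform;
# each EI value is weighted by the predicted probability of feasibility.
values = acqf(torch.rand(5, 1, 2, dtype=torch.double))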
@@ -0,0 +1,134 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Optional, Tuple, Union

import torch
from botorch.utils.probability import TruncatedMultivariateNormal, UnifiedSkewNormal
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import _GaussianLikelihoodBase
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.likelihoods.noise_models import Noise
from linear_operator.operators import DiagLinearOperator, LinearOperator
from torch import BoolTensor, Tensor
from torch.nn.functional import pad


class AffineProbitLikelihood(_GaussianLikelihoodBase, Likelihood):
    def __init__(
        self,
        weight: Union[LinearOperator, Tensor],
        bias: Optional[Union[LinearOperator, Tensor]] = None,
        noise_covar: Optional[Noise] = None,
    ):
"""Affine probit likelihood `P(f + e > 0)`, where `f = Ax + b` is as an affine | ||
transformation of an `n`-dimensional Gaussian random vector `x` and `e ~ Noise` | ||
is an `m`-dimensional centered, Gaussian noise vector. | ||
Args: | ||
weight: Matrix `A` with shape (... x n x m). | ||
bias: Vector `b` with shape (... x m). | ||
noise_covar: Noise covariance matrix with shape (... x m x m). | ||
""" | ||
        Likelihood.__init__(self)
        self.weight = weight
        self.bias = bias
        self.noise_covar = noise_covar

    def get_affine_transform(
        self, diag: Optional[Tensor] = None
    ) -> Tuple[Union[Tensor, LinearOperator], Optional[Union[Tensor, LinearOperator]]]:
"""Returns the base affine transform with sign flips for negative labels. | ||
Args: | ||
diag: Scaling factors `d` for the affine transform such that (DA, Db) is | ||
returned, where `D = diag(d)`. | ||
Returns: | ||
Tensor representation of the affine transform (A, b). | ||
""" | ||
        if diag is None:
            return self.weight, self.bias

        D = DiagLinearOperator(diag)
        return D @ self.weight, None if (self.bias is None) else D @ self.bias

    def marginal(
        self,
        function_dist: MultivariateNormal,
        observations: Optional[BoolTensor] = None,
    ) -> TruncatedMultivariateNormal:
        """Returns the truncated multivariate normal distribution of `h | h > 0`, where
        `x` is a Gaussian random vector, `h = (Ax + b) + e`, and `e ~ Noise`."""
        gauss_loc = function_dist.loc
        gauss_cov = function_dist.covariance_matrix
        signed_labels = (
            None
            if observations is None
            else 2 * observations.to(dtype=gauss_loc.dtype, device=gauss_loc.device) - 1
        )

        A, b = self.get_affine_transform(diag=signed_labels)
        trunc_loc = A @ gauss_loc if (b is None) else A @ gauss_loc + b
        trunc_cov = A @ gauss_cov @ A.transpose(-1, -2)
        if self.noise_covar is not None:
            noise_diag = self.noise_covar(shape=trunc_cov.shape[:-1])
            trunc_cov = (trunc_cov + noise_diag).to_dense()

        return TruncatedMultivariateNormal(
            loc=trunc_loc,
            covariance_matrix=trunc_cov,
            bounds=pad(torch.full_like(trunc_loc, float("inf")).unsqueeze(-1), (1, 0)),
            validate_args=False,
        )

    def log_marginal(
        self,
        observations: BoolTensor,
        function_dist: MultivariateNormal,
    ) -> Tensor:
        """Returns the log marginal likelihood `ln p(y) = ln P([2y - 1](f + e) > 0)`,
        where `f = Ax + b` and `e ~ Noise`."""
        return self.marginal(function_dist, observations=observations).log_partition

    def latent_marginal(
        self,
        function_dist: MultivariateNormal,
        observations: Optional[BoolTensor] = None,
    ) -> UnifiedSkewNormal:
        """Returns the UnifiedSkewNormal distribution of `x | f + e > 0`, where
        `x` is a Gaussian random vector, `f = Ax + b`, and `e ~ Noise`."""
        gauss_loc = function_dist.loc
        gauss_cov = function_dist.covariance_matrix
        signed_labels = (
            None
            if observations is None
            else 2 * observations.to(dtype=gauss_loc.dtype, device=gauss_loc.device) - 1
        )

        A, b = self.get_affine_transform(diag=signed_labels)
        trunc_loc = A @ gauss_loc if (b is None) else A @ gauss_loc + b
        cross_cov = A @ gauss_cov
        trunc_cov = cross_cov @ A.transpose(-1, -2)
        if self.noise_covar is not None:
            noise_diag = self.noise_covar(shape=trunc_cov.shape[:-1])
            trunc_cov = (trunc_cov + noise_diag).to_dense()

        trunc = TruncatedMultivariateNormal(
            loc=trunc_loc,
            covariance_matrix=trunc_cov,
            bounds=pad(torch.full_like(trunc_loc, float("inf")).unsqueeze(-1), (1, 0)),
            validate_args=False,
        )

        return UnifiedSkewNormal(
            trunc=trunc,
            gauss=function_dist,
            cross_covariance_matrix=cross_cov,
        )
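
As a sanity check on the likelihood math (again a sketch, not part of the diff): with an identity weight, no bias, no noise, and a standard normal prior, the coordinates decouple and each Boolean label is matched with probability 1/2, so `log_marginal` should return roughly `n * log(0.5)`. This assumes `TruncatedMultivariateNormal.log_partition` and `UnifiedSkewNormal.rsample` behave as their names suggest; the partition estimate may be approximate, hence the loose tolerance.

import math

import torch
from gpytorch.distributions import MultivariateNormal

n = 3
likelihood = AffineProbitLikelihood(weight=torch.eye(n, dtype=torch.double))
prior = MultivariateNormal(
    torch.zeros(n, dtype=torch.double), torch.eye(n, dtype=torch.double)
)
y = torch.tensor([True, False, True])

# Independent coordinates under the standard normal prior: each label is
# correct with probability 0.5, so ln p(y) ~= 3 * ln(0.5) ~= -2.079.
log_p = likelihood.log_marginal(observations=y, function_dist=prior)
assert abs(log_p.item() - n * math.log(0.5)) < 0.05

# The latent posterior x | y is the UnifiedSkewNormal returned by
# latent_marginal; draw a few (approximate) samples from it.
posterior = likelihood.latent_marginal(function_dist=prior, observations=y)
samples = posterior.rsample(torch.Size([4]))  # 4 draws from x | y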