Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update use cases of deprecated fixed-noise models #2059

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions botorch/acquisition/analytic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions import UnsupportedError
from botorch.models.gp_regression import FixedNoiseGP
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model import Model
from botorch.utils.constants import get_constants_like
Expand All @@ -38,6 +38,7 @@
)
from botorch.utils.safe_math import log1mexp, logmeanexp
from botorch.utils.transforms import convert_to_target_pre_hook, t_batch_mode_transform
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from torch import Tensor
from torch.nn.functional import pad

Expand Down Expand Up @@ -580,11 +581,12 @@ class LogNoisyExpectedImprovement(AnalyticAcquisitionFunction):

where `X_base` are previously observed points.

Note: This acquisition function currently relies on using a FixedNoiseGP (required
for noiseless fantasies).
Note: This acquisition function currently relies on using a SingleTaskGP
with known observation noise. In other words, `train_Yvar` must be passed
to the model. (required for noiseless fantasies).

Example:
>>> model = FixedNoiseGP(train_X, train_Y, train_Yvar=train_Yvar)
>>> model = SingleTaskGP(train_X, train_Y, train_Yvar=train_Yvar)
>>> LogNEI = LogNoisyExpectedImprovement(model, train_X)
>>> nei = LogNEI(test_X)
"""
Expand All @@ -608,10 +610,6 @@ def __init__(
complexity and performance).
maximize: If True, consider the problem a maximization problem.
"""
if not isinstance(model, FixedNoiseGP):
raise UnsupportedError(
"Only FixedNoiseGPs are currently supported for fantasy LogNEI"
)
# sample fantasies
from botorch.sampling.normal import SobolQMCNormalSampler

Expand Down Expand Up @@ -662,11 +660,12 @@ class NoisyExpectedImprovement(ExpectedImprovement):
`NEI(x) = E(max(y - max Y_baseline), 0)), (y, Y_baseline) ~ f((x, X_baseline))`,
where `X_baseline` are previously observed points.

Note: This acquisition function currently relies on using a FixedNoiseGP (required
for noiseless fantasies).
Note: This acquisition function currently relies on using a SingleTaskGP
with known observation noise. In other words, `train_Yvar` must be passed
to the model. (required for noiseless fantasies).

Example:
>>> model = FixedNoiseGP(train_X, train_Y, train_Yvar=train_Yvar)
>>> model = SingleTaskGP(train_X, train_Y, train_Yvar=train_Yvar)
>>> NEI = NoisyExpectedImprovement(model, train_X)
>>> nei = NEI(test_X)
"""
Expand All @@ -689,10 +688,6 @@ def __init__(
complexity and performance).
maximize: If True, consider the problem a maximization problem.
"""
if not isinstance(model, FixedNoiseGP):
raise UnsupportedError(
"Only FixedNoiseGPs are currently supported for fantasy NEI"
)
# sample fantasies
from botorch.sampling.normal import SobolQMCNormalSampler

Expand Down Expand Up @@ -974,15 +969,15 @@ def logerfcx(x: Tensor) -> Tensor:


def _get_noiseless_fantasy_model(
model: FixedNoiseGP, batch_X_observed: Tensor, Y_fantasized: Tensor
) -> FixedNoiseGP:
model: SingleTaskGP, batch_X_observed: Tensor, Y_fantasized: Tensor
) -> SingleTaskGP:
r"""Construct a fantasy model from a fitted model and provided fantasies.

The fantasy model uses the hyperparameters from the original fitted model and
assumes the fantasies are noiseless.

Args:
model: a fitted FixedNoiseGP
model: A fitted SingleTaskGP with known observation noise.
batch_X_observed: A `b x n x d` tensor of inputs where `b` is the number of
fantasies.
Y_fantasized: A `b x n` tensor of fantasized targets where `b` is the number of
Expand All @@ -991,11 +986,18 @@ def _get_noiseless_fantasy_model(
Returns:
The fantasy model.
"""
# initialize a copy of FixedNoiseGP on the original training inputs
# this makes FixedNoiseGP a non-batch GP, so that the same hyperparameters
if not isinstance(model, SingleTaskGP) or not isinstance(
model.likelihood, FixedNoiseGaussianLikelihood
):
raise UnsupportedError(
"Only SingleTaskGP models with known observation noise "
"are currently supported for fantasy-based NEI & LogNEI."
)
# initialize a copy of SingleTaskGP on the original training inputs
# this makes SingleTaskGP a non-batch GP, so that the same hyperparameters
# are used across all batches (by default, a GP with batched training data
# uses independent hyperparameters for each batch).
fantasy_model = FixedNoiseGP(
fantasy_model = SingleTaskGP(
train_X=model.train_inputs[0],
train_Y=model.train_targets.unsqueeze(-1),
train_Yvar=model.likelihood.noise_covar.noise.unsqueeze(-1),
Expand Down
16 changes: 9 additions & 7 deletions botorch/models/contextual.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,29 @@

from typing import Any, Dict, List, Optional

from botorch.models.gp_regression import FixedNoiseGP
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.kernels.contextual_lcea import LCEAKernel
from botorch.models.kernels.contextual_sac import SACKernel
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor


class SACGP(FixedNoiseGP):
class SACGP(SingleTaskGP):
r"""A GP using a Structural Additive Contextual(SAC) kernel."""

def __init__(
self,
train_X: Tensor,
train_Y: Tensor,
train_Yvar: Tensor,
train_Yvar: Optional[Tensor],
decomposition: Dict[str, List[int]],
) -> None:
r"""
Args:
train_X: (n x d) X training data.
train_Y: (n x 1) Y training data.
train_Yvar: (n x 1) Noise variances of each training Y.
train_Yvar: (n x 1) Noise variances of each training Y. If None,
we use an inferred noise likelihood.
decomposition: Keys are context names. Values are the indexes of
parameters belong to the context. The parameter indexes are in
the same order across contexts.
Expand Down Expand Up @@ -62,7 +63,7 @@ def construct_inputs(
}


class LCEAGP(FixedNoiseGP):
class LCEAGP(SingleTaskGP):
r"""A GP using a Latent Context Embedding Additive (LCE-A) Kernel.

Note that the model does not support batch training. Input training
Expand All @@ -73,7 +74,7 @@ def __init__(
self,
train_X: Tensor,
train_Y: Tensor,
train_Yvar: Tensor,
train_Yvar: Optional[Tensor],
decomposition: Dict[str, List[int]],
train_embedding: bool = True,
cat_feature_dict: Optional[Dict] = None,
Expand All @@ -85,7 +86,8 @@ def __init__(
Args:
train_X: (n x d) X training data.
train_Y: (n x 1) Y training data.
train_Yvar: (n x 1) Noise variance of Y.
train_Yvar: (n x 1) Noise variance of Y. If None,
we use an inferred noise likelihood.
decomposition: Keys are context names. Values are the indexes of
parameters belong to the context.
train_embedding: Whether to train the embedding layer or not. If False,
Expand Down
1 change: 1 addition & 0 deletions botorch/models/contextual_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ class FixedNoiseLCEMGP(LCEMGP):
(LCE-M) kernel, with known observation noise.

DEPRECATED: Please use `LCEMGP` with `train_Yvar` instead.
Will be removed in a future release (~v0.11).
"""

def __init__(
Expand Down
9 changes: 5 additions & 4 deletions botorch/models/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@

import torch
from botorch.exceptions import UnsupportedError
from botorch.models.gp_regression import FixedNoiseGP, HeteroskedasticSingleTaskGP
from botorch.models.gp_regression import HeteroskedasticSingleTaskGP
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from torch import Tensor
from torch.nn import Module

Expand Down Expand Up @@ -140,7 +141,7 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch
train_X = deepcopy(models[0].train_inputs[0])
train_Y = torch.stack([m.train_targets.clone() for m in models], dim=-1)
kwargs = {"train_X": train_X, "train_Y": train_Y}
if isinstance(models[0], FixedNoiseGP):
if isinstance(models[0].likelihood, FixedNoiseGaussianLikelihood):
kwargs["train_Yvar"] = torch.stack(
[m.likelihood.noise_covar.noise.clone() for m in models], dim=-1
)
Expand Down Expand Up @@ -302,7 +303,7 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
.clone()
.unsqueeze(-1),
}
if isinstance(batch_model, FixedNoiseGP):
if isinstance(batch_model.likelihood, FixedNoiseGaussianLikelihood):
noise_covar = batch_model.likelihood.noise_covar
kwargs["train_Yvar"] = (
noise_covar.noise.select(input_bdims, i).clone().unsqueeze(-1)
Expand Down Expand Up @@ -390,7 +391,7 @@ def batched_multi_output_to_single_output(
"train_X": batch_mo_model.train_inputs[0].clone(),
"train_Y": batch_mo_model.train_targets.clone().unsqueeze(-1),
}
if isinstance(batch_mo_model, FixedNoiseGP):
if isinstance(batch_mo_model.likelihood, FixedNoiseGaussianLikelihood):
noise_covar = batch_mo_model.likelihood.noise_covar
kwargs["train_Yvar"] = noise_covar.noise.clone().unsqueeze(-1)
if isinstance(batch_mo_model, SingleTaskMultiFidelityGP):
Expand Down
2 changes: 1 addition & 1 deletion botorch/models/fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
>>> ])
>>> train_Y = torch.cat(f1(X1), f2(X2)).unsqueeze(-1)
>>> train_Yvar = 0.01 * torch.ones_like(train_Y)
>>> mtsaas_gp = SaasFullyBayesianFixedNoiseMultiTaskGP(
>>> mtsaas_gp = SaasFullyBayesianMultiTaskGP(
>>> train_X, train_Y, train_Yvar, task_feature=-1,
>>> )
>>> fit_fully_bayesian_model_nuts(mtsaas_gp)
Expand Down
Loading