diff --git a/botorch/acquisition/analytic.py b/botorch/acquisition/analytic.py index 2ad538acc2..25d808457c 100644 --- a/botorch/acquisition/analytic.py +++ b/botorch/acquisition/analytic.py @@ -24,7 +24,7 @@ from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.objective import PosteriorTransform from botorch.exceptions import UnsupportedError -from botorch.models.gp_regression import FixedNoiseGP +from botorch.models.gp_regression import SingleTaskGP from botorch.models.gpytorch import GPyTorchModel from botorch.models.model import Model from botorch.utils.constants import get_constants_like @@ -38,6 +38,7 @@ ) from botorch.utils.safe_math import log1mexp, logmeanexp from botorch.utils.transforms import convert_to_target_pre_hook, t_batch_mode_transform +from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from torch import Tensor from torch.nn.functional import pad @@ -580,11 +581,12 @@ class LogNoisyExpectedImprovement(AnalyticAcquisitionFunction): where `X_base` are previously observed points. - Note: This acquisition function currently relies on using a FixedNoiseGP (required - for noiseless fantasies). + Note: This acquisition function currently relies on using a SingleTaskGP + with known observation noise. In other words, `train_Yvar` must be passed + to the model. (required for noiseless fantasies). Example: - >>> model = FixedNoiseGP(train_X, train_Y, train_Yvar=train_Yvar) + >>> model = SingleTaskGP(train_X, train_Y, train_Yvar=train_Yvar) >>> LogNEI = LogNoisyExpectedImprovement(model, train_X) >>> nei = LogNEI(test_X) """ @@ -608,10 +610,6 @@ def __init__( complexity and performance). maximize: If True, consider the problem a maximization problem. """ - if not isinstance(model, FixedNoiseGP): - raise UnsupportedError( - "Only FixedNoiseGPs are currently supported for fantasy LogNEI" - ) # sample fantasies from botorch.sampling.normal import SobolQMCNormalSampler @@ -662,11 +660,12 @@ class NoisyExpectedImprovement(ExpectedImprovement): `NEI(x) = E(max(y - max Y_baseline), 0)), (y, Y_baseline) ~ f((x, X_baseline))`, where `X_baseline` are previously observed points. - Note: This acquisition function currently relies on using a FixedNoiseGP (required - for noiseless fantasies). + Note: This acquisition function currently relies on using a SingleTaskGP + with known observation noise. In other words, `train_Yvar` must be passed + to the model. (required for noiseless fantasies). Example: - >>> model = FixedNoiseGP(train_X, train_Y, train_Yvar=train_Yvar) + >>> model = SingleTaskGP(train_X, train_Y, train_Yvar=train_Yvar) >>> NEI = NoisyExpectedImprovement(model, train_X) >>> nei = NEI(test_X) """ @@ -689,10 +688,6 @@ def __init__( complexity and performance). maximize: If True, consider the problem a maximization problem. """ - if not isinstance(model, FixedNoiseGP): - raise UnsupportedError( - "Only FixedNoiseGPs are currently supported for fantasy NEI" - ) # sample fantasies from botorch.sampling.normal import SobolQMCNormalSampler @@ -974,15 +969,15 @@ def logerfcx(x: Tensor) -> Tensor: def _get_noiseless_fantasy_model( - model: FixedNoiseGP, batch_X_observed: Tensor, Y_fantasized: Tensor -) -> FixedNoiseGP: + model: SingleTaskGP, batch_X_observed: Tensor, Y_fantasized: Tensor +) -> SingleTaskGP: r"""Construct a fantasy model from a fitted model and provided fantasies. The fantasy model uses the hyperparameters from the original fitted model and assumes the fantasies are noiseless. 
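# A minimal, self-contained sketch of the usage pattern enabled by the change
# above: NEI / LogNEI now expect a `SingleTaskGP` constructed with `train_Yvar`
# (i.e. a fixed-noise likelihood) rather than the deprecated `FixedNoiseGP`.
# The data, shapes, and noise level below are purely illustrative.
import torch
from botorch.acquisition.analytic import LogNoisyExpectedImprovement
from botorch.models.gp_regression import SingleTaskGP

train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = torch.sin(6 * train_X).sum(dim=-1, keepdim=True)
train_Yvar = torch.full_like(train_Y, 0.01)  # known observation noise
model = SingleTaskGP(train_X, train_Y, train_Yvar=train_Yvar)
log_nei = LogNoisyExpectedImprovement(model, X_observed=train_X, num_fantasies=20)
test_X = torch.rand(5, 1, 2, dtype=torch.double)  # t-batch of 5 single-point candidates
values = log_nei(test_X)  # tensor of shape (5,)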
Args: - model: a fitted FixedNoiseGP + model: A fitted SingleTaskGP with known observation noise. batch_X_observed: A `b x n x d` tensor of inputs where `b` is the number of fantasies. Y_fantasized: A `b x n` tensor of fantasized targets where `b` is the number of @@ -991,11 +986,18 @@ def _get_noiseless_fantasy_model( Returns: The fantasy model. """ - # initialize a copy of FixedNoiseGP on the original training inputs - # this makes FixedNoiseGP a non-batch GP, so that the same hyperparameters + if not isinstance(model, SingleTaskGP) or not isinstance( + model.likelihood, FixedNoiseGaussianLikelihood + ): + raise UnsupportedError( + "Only SingleTaskGP models with known observation noise " + "are currently supported for fantasy-based NEI & LogNEI." + ) + # initialize a copy of SingleTaskGP on the original training inputs + # this makes SingleTaskGP a non-batch GP, so that the same hyperparameters # are used across all batches (by default, a GP with batched training data # uses independent hyperparameters for each batch). - fantasy_model = FixedNoiseGP( + fantasy_model = SingleTaskGP( train_X=model.train_inputs[0], train_Y=model.train_targets.unsqueeze(-1), train_Yvar=model.likelihood.noise_covar.noise.unsqueeze(-1), diff --git a/botorch/models/contextual.py b/botorch/models/contextual.py index 0e99c4336c..5c9993bf93 100644 --- a/botorch/models/contextual.py +++ b/botorch/models/contextual.py @@ -6,28 +6,29 @@ from typing import Any, Dict, List, Optional -from botorch.models.gp_regression import FixedNoiseGP +from botorch.models.gp_regression import SingleTaskGP from botorch.models.kernels.contextual_lcea import LCEAKernel from botorch.models.kernels.contextual_sac import SACKernel from botorch.utils.datasets import SupervisedDataset from torch import Tensor -class SACGP(FixedNoiseGP): +class SACGP(SingleTaskGP): r"""A GP using a Structural Additive Contextual(SAC) kernel.""" def __init__( self, train_X: Tensor, train_Y: Tensor, - train_Yvar: Tensor, + train_Yvar: Optional[Tensor], decomposition: Dict[str, List[int]], ) -> None: r""" Args: train_X: (n x d) X training data. train_Y: (n x 1) Y training data. - train_Yvar: (n x 1) Noise variances of each training Y. + train_Yvar: (n x 1) Noise variances of each training Y. If None, + we use an inferred noise likelihood. decomposition: Keys are context names. Values are the indexes of parameters belong to the context. The parameter indexes are in the same order across contexts. @@ -62,7 +63,7 @@ def construct_inputs( } -class LCEAGP(FixedNoiseGP): +class LCEAGP(SingleTaskGP): r"""A GP using a Latent Context Embedding Additive (LCE-A) Kernel. Note that the model does not support batch training. Input training @@ -73,7 +74,7 @@ def __init__( self, train_X: Tensor, train_Y: Tensor, - train_Yvar: Tensor, + train_Yvar: Optional[Tensor], decomposition: Dict[str, List[int]], train_embedding: bool = True, cat_feature_dict: Optional[Dict] = None, @@ -85,7 +86,8 @@ def __init__( Args: train_X: (n x d) X training data. train_Y: (n x 1) Y training data. - train_Yvar: (n x 1) Noise variance of Y. + train_Yvar: (n x 1) Noise variance of Y. If None, + we use an inferred noise likelihood. decomposition: Keys are context names. Values are the indexes of parameters belong to the context. train_embedding: Whether to train the embedding layer or not. 
If False, diff --git a/botorch/models/contextual_multioutput.py b/botorch/models/contextual_multioutput.py index 27870e18e6..0581acd104 100644 --- a/botorch/models/contextual_multioutput.py +++ b/botorch/models/contextual_multioutput.py @@ -170,6 +170,7 @@ class FixedNoiseLCEMGP(LCEMGP): (LCE-M) kernel, with known observation noise. DEPRECATED: Please use `LCEMGP` with `train_Yvar` instead. + Will be removed in a future release (~v0.11). """ def __init__( diff --git a/botorch/models/converter.py b/botorch/models/converter.py index 04e292bf3d..d87b8a237d 100644 --- a/botorch/models/converter.py +++ b/botorch/models/converter.py @@ -15,13 +15,14 @@ import torch from botorch.exceptions import UnsupportedError -from botorch.models.gp_regression import FixedNoiseGP, HeteroskedasticSingleTaskGP +from botorch.models.gp_regression import HeteroskedasticSingleTaskGP from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP from botorch.models.gp_regression_mixed import MixedSingleTaskGP from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.transforms.input import InputTransform from botorch.models.transforms.outcome import OutcomeTransform +from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from torch import Tensor from torch.nn import Module @@ -140,7 +141,7 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch train_X = deepcopy(models[0].train_inputs[0]) train_Y = torch.stack([m.train_targets.clone() for m in models], dim=-1) kwargs = {"train_X": train_X, "train_Y": train_Y} - if isinstance(models[0], FixedNoiseGP): + if isinstance(models[0].likelihood, FixedNoiseGaussianLikelihood): kwargs["train_Yvar"] = torch.stack( [m.likelihood.noise_covar.noise.clone() for m in models], dim=-1 ) @@ -302,7 +303,7 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model .clone() .unsqueeze(-1), } - if isinstance(batch_model, FixedNoiseGP): + if isinstance(batch_model.likelihood, FixedNoiseGaussianLikelihood): noise_covar = batch_model.likelihood.noise_covar kwargs["train_Yvar"] = ( noise_covar.noise.select(input_bdims, i).clone().unsqueeze(-1) @@ -390,7 +391,7 @@ def batched_multi_output_to_single_output( "train_X": batch_mo_model.train_inputs[0].clone(), "train_Y": batch_mo_model.train_targets.clone().unsqueeze(-1), } - if isinstance(batch_mo_model, FixedNoiseGP): + if isinstance(batch_mo_model.likelihood, FixedNoiseGaussianLikelihood): noise_covar = batch_mo_model.likelihood.noise_covar kwargs["train_Yvar"] = noise_covar.noise.clone().unsqueeze(-1) if isinstance(batch_mo_model, SingleTaskMultiFidelityGP): diff --git a/botorch/models/fully_bayesian_multitask.py b/botorch/models/fully_bayesian_multitask.py index 6a7b7220ce..e8d2182404 100644 --- a/botorch/models/fully_bayesian_multitask.py +++ b/botorch/models/fully_bayesian_multitask.py @@ -182,7 +182,7 @@ class SaasFullyBayesianMultiTaskGP(MultiTaskGP): >>> ]) >>> train_Y = torch.cat(f1(X1), f2(X2)).unsqueeze(-1) >>> train_Yvar = 0.01 * torch.ones_like(train_Y) - >>> mtsaas_gp = SaasFullyBayesianFixedNoiseMultiTaskGP( + >>> mtsaas_gp = SaasFullyBayesianMultiTaskGP( >>> train_X, train_Y, train_Yvar, task_feature=-1, >>> ) >>> fit_fully_bayesian_model_nuts(mtsaas_gp) diff --git a/botorch/models/gp_regression.py b/botorch/models/gp_regression.py index f69b50f08e..43a7c8e50f 100644 --- a/botorch/models/gp_regression.py +++ b/botorch/models/gp_regression.py @@ 
-30,7 +30,8 @@ from __future__ import annotations -from typing import Any, List, NoReturn, Optional +import warnings +from typing import NoReturn, Optional import torch from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel @@ -43,7 +44,6 @@ get_matern_kernel_with_gamma_prior, MIN_INFERRED_NOISE_LEVEL, ) -from botorch.sampling.base import MCSampler from gpytorch.constraints.constraints import GreaterThan from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.likelihoods.gaussian_likelihood import ( @@ -63,7 +63,7 @@ class SingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin): - r"""A single-task exact GP model. + r"""A single-task exact GP model, supporting both known and inferred noise levels. A single-task exact GP using relatively strong priors on the Kernel hyperparameters, which work best when covariates are normalized to the unit @@ -78,16 +78,35 @@ class SingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin): training data, use the ModelListGP. When modeling correlations between outputs, use the MultiTaskGP. + An example of a case in which noise levels are known is online + experimentation, where noise can be measured using the variability of + different observations from the same arm, or provided by outside software. + Another use case is simulation optimization, where the evaluation can + provide variance estimates, perhaps from bootstrapping. In any case, these + noise levels can be provided to `SingleTaskGP` as `train_Yvar`. + + `SingleTaskGP` can also be used when the observations are known to be + noise-free. Noise-free observations can be modeled using arbitrarily small + noise values, such as `train_Yvar=torch.full_like(train_Y, 1e-6)`. + Example: + >>> # Model with inferred noise levels. >>> train_X = torch.rand(20, 2) >>> train_Y = torch.sin(train_X).sum(dim=1, keepdim=True) - >>> model = SingleTaskGP(train_X, train_Y) + >>> inferred_noise_model = SingleTaskGP(train_X, train_Y) + >>> # With known observation variance of 0.2. + >>> train_Yvar = torch.full_like(train_Y, 0.2) + >>> observed_noise_model = SingleTaskGP(train_X, train_Y, train_Yvar) + >>> # To model noise-free observations. + >>> train_Yvar = torch.full_like(train_Y, 1e-6) + >>> noise_free_model = SingleTaskGP(train_X, train_Y, train_Yvar) """ def __init__( self, train_X: Tensor, train_Y: Tensor, + train_Yvar: Optional[Tensor] = None, likelihood: Optional[Likelihood] = None, covar_module: Optional[Module] = None, mean_module: Optional[Mean] = None, @@ -98,8 +117,12 @@ def __init__( Args: train_X: A `batch_shape x n x d` tensor of training features. train_Y: A `batch_shape x n x m` tensor of training observations. + train_Yvar: An optional `batch_shape x n x m` tensor of observed + measurement noise. likelihood: A likelihood. If omitted, use a standard - GaussianLikelihood with inferred noise level. + `GaussianLikelihood` with inferred noise level if `train_Yvar` + is None, and a `FixedNoiseGaussianLikelihood` with the given + noise observations if `train_Yvar` is not None. covar_module: The module computing the covariance (Kernel) matrix. If omitted, use a `MaternKernel`. mean_module: The mean function to be used. 
If omitted, use a @@ -116,18 +139,28 @@ def __init__( X=train_X, input_transform=input_transform ) if outcome_transform is not None: - train_Y, _ = outcome_transform(train_Y) - self._validate_tensor_args(X=transformed_X, Y=train_Y) + train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar) + self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar) ignore_X_dims = getattr(self, "_ignore_X_dims_scaling_check", None) validate_input_scaling( - train_X=transformed_X, train_Y=train_Y, ignore_X_dims=ignore_X_dims + train_X=transformed_X, + train_Y=train_Y, + train_Yvar=train_Yvar, + ignore_X_dims=ignore_X_dims, ) self._set_dimensions(train_X=train_X, train_Y=train_Y) - train_X, train_Y, _ = self._transform_tensor_args(X=train_X, Y=train_Y) + train_X, train_Y, train_Yvar = self._transform_tensor_args( + X=train_X, Y=train_Y, Yvar=train_Yvar + ) if likelihood is None: - likelihood = get_gaussian_likelihood_with_gamma_prior( - batch_shape=self._aug_batch_shape - ) + if train_Yvar is None: + likelihood = get_gaussian_likelihood_with_gamma_prior( + batch_shape=self._aug_batch_shape + ) + else: + likelihood = FixedNoiseGaussianLikelihood( + noise=train_Yvar, batch_shape=self._aug_batch_shape + ) else: self._is_custom_likelihood = True ExactGP.__init__( @@ -142,11 +175,12 @@ def __init__( batch_shape=self._aug_batch_shape, ) self._subset_batch_dict = { - "likelihood.noise_covar.raw_noise": -2, "mean_module.raw_constant": -1, "covar_module.raw_outputscale": -1, "covar_module.base_kernel.raw_lengthscale": -3, } + if train_Yvar is None: + self._subset_batch_dict["likelihood.noise_covar.raw_noise"] = -2 self.covar_module = covar_module # TODO: Allow subsetting of other covar modules if outcome_transform is not None: @@ -163,37 +197,12 @@ def forward(self, x: Tensor) -> MultivariateNormal: return MultivariateNormal(mean_x, covar_x) -class FixedNoiseGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin): +class FixedNoiseGP(SingleTaskGP): r"""A single-task exact GP model using fixed noise levels. - A single-task exact GP that uses fixed observation noise levels, differing from - `SingleTaskGP` only in that noise levels are provided rather than inferred. - This model also uses relatively strong priors on the Kernel hyperparameters, - which work best when covariates are normalized to the unit cube and outcomes - are standardized (zero mean, unit variance). - - This model works in batch mode (each batch having its own hyperparameters). - - An example of a case in which noise levels are known is online - experimentation, where noise can be measured using the variability of - different observations from the same arm, or provided by outside software. - Another use case is simulation optimization, where the evaluation can - provide variance estimates, perhaps from bootstrapping. In any case, these - noise levels must be provided to `FixedNoiseGP` as `train_Yvar`. - - `FixedNoiseGP` is also commonly used when the observations are known to be - noise-free. Noise-free observations can be modeled using arbitrarily small - noise values, such as `train_Yvar=torch.full_like(train_Y, 1e-6)`. - - `FixedNoiseGP` cannot predict noise levels out of sample. If this is needed, - use `HeteroskedasticSingleTaskGP`, which will create another model for the - observation noise. 
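# A quick sketch of the likelihood selection introduced above (data and shapes
# are illustrative): omitting `train_Yvar` yields an inferred-noise
# `GaussianLikelihood`, while passing `train_Yvar` yields a
# `FixedNoiseGaussianLikelihood` holding the provided noise variances.
import torch
from botorch.models.gp_regression import SingleTaskGP
from gpytorch.likelihoods.gaussian_likelihood import (
    FixedNoiseGaussianLikelihood,
    GaussianLikelihood,
)

train_X = torch.rand(10, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
inferred_noise_model = SingleTaskGP(train_X, train_Y)
assert isinstance(inferred_noise_model.likelihood, GaussianLikelihood)
observed_noise_model = SingleTaskGP(
    train_X, train_Y, train_Yvar=torch.full_like(train_Y, 1e-4)
)
assert isinstance(observed_noise_model.likelihood, FixedNoiseGaussianLikelihood)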
- - Example: - >>> train_X = torch.rand(20, 2) - >>> train_Y = torch.sin(train_X).sum(dim=1, keepdim=True) - >>> train_Yvar = torch.full_like(train_Y, 0.2) - >>> model = FixedNoiseGP(train_X, train_Y, train_Yvar) + DEPRECATED: `FixedNoiseGP` has been merged into `SingleTaskGP`. Please use + `SingleTaskGP` with `train_Yvar` instead. + Will be removed in a future release (~v0.12). """ def __init__( @@ -206,140 +215,22 @@ def __init__( outcome_transform: Optional[OutcomeTransform] = None, input_transform: Optional[InputTransform] = None, ) -> None: - r""" - Args: - train_X: A `batch_shape x n x d` tensor of training features. - train_Y: A `batch_shape x n x m` tensor of training observations. - train_Yvar: A `batch_shape x n x m` tensor of observed measurement - noise. - covar_module: The module computing the covariance (Kernel) matrix. - If omitted, use a `MaternKernel`. - mean_module: The mean function to be used. If omitted, use a - `ConstantMean`. - outcome_transform: An outcome transform that is applied to the - training data during instantiation and to the posterior during - inference (that is, the `Posterior` obtained by calling - `.posterior` on the model will be on the original scale). - input_transform: An input transfrom that is applied in the model's - forward pass. - """ - with torch.no_grad(): - transformed_X = self.transform_inputs( - X=train_X, input_transform=input_transform - ) - if outcome_transform is not None: - train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar) - self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar) - validate_input_scaling( - train_X=transformed_X, train_Y=train_Y, train_Yvar=train_Yvar - ) - self._set_dimensions(train_X=train_X, train_Y=train_Y) - train_X, train_Y, train_Yvar = self._transform_tensor_args( - X=train_X, Y=train_Y, Yvar=train_Yvar - ) - likelihood = FixedNoiseGaussianLikelihood( - noise=train_Yvar, batch_shape=self._aug_batch_shape + r"""DEPRECATED. See SingleTaskGP.""" + warnings.warn( + "`FixedNoiseGP` has been merged into `SingleTaskGP`. " + "Please use `SingleTaskGP` with `train_Yvar` instead.", + DeprecationWarning, ) - ExactGP.__init__( - self, train_inputs=train_X, train_targets=train_Y, likelihood=likelihood - ) - if mean_module is None: - mean_module = ConstantMean(batch_shape=self._aug_batch_shape) - self.mean_module = mean_module - if covar_module is None: - covar_module = get_matern_kernel_with_gamma_prior( - ard_num_dims=transformed_X.shape[-1], - batch_shape=self._aug_batch_shape, - ) - self._subset_batch_dict = { - "mean_module.raw_constant": -1, - "covar_module.raw_outputscale": -1, - "covar_module.base_kernel.raw_lengthscale": -3, - } - self.covar_module = covar_module - # TODO: Allow subsetting of other covar modules - if input_transform is not None: - self.input_transform = input_transform - if outcome_transform is not None: - self.outcome_transform = outcome_transform - - self.to(train_X) - - def fantasize( - self, - X: Tensor, - sampler: MCSampler, - observation_noise: Optional[Tensor] = None, - **kwargs: Any, - ) -> FixedNoiseGP: - r"""Construct a fantasy model. - - Constructs a fantasy model in the following fashion: - (1) compute the model posterior at `X` (if `observation_noise=True`, - this includes observation noise taken as the mean across the observation - noise in the training data. If `observation_noise` is a Tensor, use - it directly as the observation noise to add). - (2) sample from this posterior (using `sampler`) to generate "fake" - observations. 
- (3) condition the model on the new fake observations. - - Args: - X: A `batch_shape x n' x d`-dim Tensor, where `d` is the dimension of - the feature space, `n'` is the number of points per batch, and - `batch_shape` is the batch shape (must be compatible with the - batch shape of the model). - sampler: The sampler used for sampling from the posterior at `X`. - observation_noise: The noise level for fantasization if - provided. If `None`, the mean across the observation - noise in the training data is used as observation noise in - the posterior from which the samples are drawn and - the fantasized noise level. If observation noise is - provided, it is assumed to be in the outcome-transformed - space, if an outcome transform is used. - - Returns: - The constructed fantasy model. - """ - # self.likelihood.noise is an `batch_shape x n x s(m)`-dimensional tensor - if observation_noise is None: - if self.num_outputs > 1: - # make noise ... x n x m - observation_noise = self.likelihood.noise.transpose(-1, -2) - else: - observation_noise = self.likelihood.noise.unsqueeze(-1) - observation_noise = observation_noise.mean(dim=-2, keepdim=True) - - return super().fantasize( - X=X, - sampler=sampler, - observation_noise=observation_noise, - **kwargs, + super().__init__( + train_X=train_X, + train_Y=train_Y, + train_Yvar=train_Yvar, + covar_module=covar_module, + mean_module=mean_module, + outcome_transform=outcome_transform, + input_transform=input_transform, ) - def forward(self, x: Tensor) -> MultivariateNormal: - # TODO: reduce redundancy with the 'forward' method of - # SingleTaskGP, which is identical - if self.training: - x = self.transform_inputs(x) - mean_x = self.mean_module(x) - covar_x = self.covar_module(x) - return MultivariateNormal(mean_x, covar_x) - - def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel: - r"""Subset the model along the output dimension. - - Args: - idcs: The output indices to subset the model to. - - Returns: - The current model, subset to the specified output indices. - """ - new_model = super().subset_output(idcs=idcs) - full_noise = new_model.likelihood.noise_covar.noise - new_noise = full_noise[..., idcs if len(idcs) > 1 else idcs[0], :] - new_model.likelihood.noise_covar.noise = new_noise - return new_model - class HeteroskedasticSingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP): r"""A single-task exact GP model using a heteroskedastic noise model. diff --git a/botorch/models/gp_regression_fidelity.py b/botorch/models/gp_regression_fidelity.py index 2ff0a64cd5..88091a998d 100644 --- a/botorch/models/gp_regression_fidelity.py +++ b/botorch/models/gp_regression_fidelity.py @@ -31,7 +31,7 @@ import torch from botorch.exceptions.errors import UnsupportedError -from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP +from botorch.models.gp_regression import SingleTaskGP from botorch.models.kernels.downsampling import DownsamplingKernel from botorch.models.kernels.exponential_decay import ExponentialDecayKernel from botorch.models.kernels.linear_truncated_fidelity import ( @@ -67,6 +67,7 @@ def __init__( self, train_X: Tensor, train_Y: Tensor, + train_Yvar: Optional[Tensor] = None, iteration_fidelity: Optional[int] = None, data_fidelities: Optional[Union[List[int], Tuple[int]]] = None, data_fidelity: Optional[int] = None, @@ -82,6 +83,8 @@ def __init__( where `s` is the dimension of the fidelity parameters (either one or two). train_Y: A `batch_shape x n x m` tensor of training observations. 
+ train_Yvar: An optional `batch_shape x n x m` tensor of observed + measurement noise. iteration_fidelity: The column index for the training iteration fidelity parameter (optional). data_fidelities: The column indices for the downsampling fidelity parameter. @@ -142,17 +145,19 @@ def __init__( super().__init__( train_X=train_X, train_Y=train_Y, + train_Yvar=train_Yvar, likelihood=likelihood, covar_module=covar_module, outcome_transform=outcome_transform, input_transform=input_transform, ) self._subset_batch_dict = { - "likelihood.noise_covar.raw_noise": -2, "mean_module.raw_constant": -1, "covar_module.raw_outputscale": -1, **subset_batch_dict, } + if train_Yvar is None: + self._subset_batch_dict["likelihood.noise_covar.raw_noise"] = -2 self.to(train_X) @classmethod @@ -173,27 +178,7 @@ def construct_inputs( return inputs -class FixedNoiseMultiFidelityGP(FixedNoiseGP): - r"""A single task multi-fidelity GP model using fixed noise levels. - - A FixedNoiseGP model analogue to SingleTaskMultiFidelityGP, using a - DownsamplingKernel for the data fidelity parameter (if present) and - an ExponentialDecayKernel for the iteration fidelity parameter (if present). - - This kernel is described in [Wu2019mf]_. - - Example: - >>> train_X = torch.rand(20, 4) - >>> train_Y = train_X.pow(2).sum(dim=-1, keepdim=True) - >>> train_Yvar = torch.full_like(train_Y) * 0.01 - >>> model = FixedNoiseMultiFidelityGP( - >>> train_X, - >>> train_Y, - >>> train_Yvar, - >>> data_fidelities=[3], - >>> ) - """ - +class FixedNoiseMultiFidelityGP(SingleTaskMultiFidelityGP): def __init__( self, train_X: Tensor, @@ -207,99 +192,26 @@ def __init__( outcome_transform: Optional[OutcomeTransform] = None, input_transform: Optional[InputTransform] = None, ) -> None: - r""" - Args: - train_X: A `batch_shape x n x (d + s)` tensor of training features, - where `s` is the dimension of the fidelity parameters (either one - or two). - train_Y: A `batch_shape x n x m` tensor of training observations. - train_Yvar: A `batch_shape x n x m` tensor of observed measurement noise. - iteration_fidelity: The column index for the training iteration fidelity - parameter (optional). - data_fidelities: The column indices for the downsampling fidelity parameter. - If a list of indices is provided, a kernel will be constructed for - each index (optional). - data_fidelity: The column index for the downsampling fidelity parameter - (optional). Deprecated in favor of `data_fidelities`. - linear_truncated: If True, use a `LinearTruncatedFidelityKernel` instead - of the default kernel. - nu: The smoothness parameter for the Matern kernel: either 1/2, 3/2, or - 5/2. Only used when `linear_truncated=True`. - outcome_transform: An outcome transform that is applied to the - training data during instantiation and to the posterior during - inference (that is, the `Posterior` obtained by calling - `.posterior` on the model will be on the original scale). - input_transform: An input transform that is applied in the model's - forward pass. + r"""DEPRECATED: Use `SingleTaskMultiFidelityGP` instead. + Will be removed in a future release (~v0.11). """ - if data_fidelity is not None: - warnings.warn( - "The `data_fidelity` argument is deprecated and will be removed in " - "a future release. Please use `data_fidelities` instead.", - DeprecationWarning, - ) - if data_fidelities is not None: - raise ValueError( - "Cannot specify both `data_fidelity` and `data_fidelities`." 
- ) - data_fidelities = [data_fidelity] - - self._init_args = { - "iteration_fidelity": iteration_fidelity, - "data_fidelities": data_fidelities, - "linear_truncated": linear_truncated, - "nu": nu, - "outcome_transform": outcome_transform, - } - if iteration_fidelity is None and data_fidelities is None: - raise UnsupportedError( - "FixedNoiseMultiFidelityGP requires at least one fidelity parameter." - ) - with torch.no_grad(): - transformed_X = self.transform_inputs( - X=train_X, input_transform=input_transform - ) - self._set_dimensions(train_X=transformed_X, train_Y=train_Y) - covar_module, subset_batch_dict = _setup_multifidelity_covar_module( - dim=transformed_X.size(-1), - aug_batch_shape=self._aug_batch_shape, - iteration_fidelity=iteration_fidelity, - data_fidelities=data_fidelities, - linear_truncated=linear_truncated, - nu=nu, + warnings.warn( + "`FixedNoiseMultiFidelityGP` has been deprecated. " + "Use `SingleTaskMultiFidelityGP` with `train_Yvar` instead.", + DeprecationWarning, ) super().__init__( train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar, - covar_module=covar_module, + iteration_fidelity=iteration_fidelity, + data_fidelities=data_fidelities, + data_fidelity=data_fidelity, + linear_truncated=linear_truncated, + nu=nu, outcome_transform=outcome_transform, input_transform=input_transform, ) - self._subset_batch_dict = { - "likelihood.noise_covar.raw_noise": -2, - "mean_module.raw_constant": -1, - "covar_module.raw_outputscale": -1, - **subset_batch_dict, - } - self.to(train_X) - - @classmethod - def construct_inputs( - cls, - training_data: SupervisedDataset, - fidelity_features: List[int], - **kwargs, - ) -> Dict[str, Any]: - r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`. - - Args: - training_data: Dictionary of `SupervisedDataset`. - fidelity_features: Column indices of fidelity features. - """ - inputs = super().construct_inputs(training_data=training_data, **kwargs) - inputs["data_fidelities"] = fidelity_features - return inputs def _setup_multifidelity_covar_module( diff --git a/botorch/models/gp_regression_mixed.py b/botorch/models/gp_regression_mixed.py index 903cf7545b..fb2da32958 100644 --- a/botorch/models/gp_regression_mixed.py +++ b/botorch/models/gp_regression_mixed.py @@ -6,11 +6,9 @@ from __future__ import annotations -import warnings from typing import Any, Callable, Dict, List, Optional import torch -from botorch.exceptions.warnings import InputDataWarning from botorch.models.gp_regression import SingleTaskGP from botorch.models.kernels.categorical import CategoricalKernel from botorch.models.transforms.input import InputTransform @@ -64,6 +62,7 @@ def __init__( train_X: Tensor, train_Y: Tensor, cat_dims: List[int], + train_Yvar: Optional[Tensor] = None, cont_kernel_factory: Optional[ Callable[[torch.Size, int, List[int]], Kernel] ] = None, @@ -78,6 +77,8 @@ def __init__( train_Y: A `batch_shape x n x m` tensor of training observations. cat_dims: A list of indices corresponding to the columns of the input `X` that should be considered categorical features. + train_Yvar: An optional `batch_shape x n x m` tensor of observed + measurement noise. 
cont_kernel_factory: A method that accepts `batch_shape`, `ard_num_dims`, and `active_dims` arguments and returns an instantiated GPyTorch `Kernel` object to be used as the base kernel for the continuous @@ -118,7 +119,7 @@ def cont_kernel_factory( lengthscale_constraint=GreaterThan(1e-04), ) - if likelihood is None: + if likelihood is None and train_Yvar is None: # This Gamma prior is quite close to the Horseshoe prior min_noise = 1e-5 if train_X.dtype == torch.float else 1e-6 likelihood = GaussianLikelihood( @@ -173,6 +174,7 @@ def cont_kernel_factory( super().__init__( train_X=train_X, train_Y=train_Y, + train_Yvar=train_Yvar, likelihood=likelihood, covar_module=covar_module, outcome_transform=outcome_transform, @@ -195,13 +197,6 @@ def construct_inputs( likelihood: Optional likelihood used to constuct the model. """ base_inputs = super().construct_inputs(training_data=training_data, **kwargs) - if base_inputs.pop("train_Yvar", None) is not None: - # TODO: Remove when SingleTaskGP supports optional Yvar [T162925473]. - warnings.warn( - "`MixedSingleTaskGP` only supports inferred noise at the moment. " - "Ignoring the provided `train_Yvar` observations.", - InputDataWarning, - ) return { **base_inputs, "cat_dims": categorical_features, diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py index 9cab0f20d8..1863a5a01b 100644 --- a/botorch/models/gpytorch.py +++ b/botorch/models/gpytorch.py @@ -540,6 +540,12 @@ def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel: except AttributeError: pass + # Subset fixed noise likelihood if present. + if isinstance(self.likelihood, FixedNoiseGaussianLikelihood): + full_noise = new_model.likelihood.noise_covar.noise + new_noise = full_noise[..., idcs if len(idcs) > 1 else idcs[0], :] + new_model.likelihood.noise_covar.noise = new_noise + return new_model diff --git a/botorch/models/model.py b/botorch/models/model.py index c5d31afc44..fd720e324a 100644 --- a/botorch/models/model.py +++ b/botorch/models/model.py @@ -45,6 +45,7 @@ from botorch.sampling.list_sampler import ListSampler from botorch.utils.datasets import SupervisedDataset from botorch.utils.transforms import is_fully_bayesian +from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from torch import Tensor from torch.nn import Module, ModuleDict, ModuleList @@ -325,8 +326,9 @@ def fantasize( r"""Construct a fantasy model. Constructs a fantasy model in the following fashion: - (1) compute the model posterior at `X` (including observation noise if - `observation_noise=True`). + (1) compute the model posterior at `X`, including observation noise. + If `observation_noise` is a Tensor, use it directly as the observation + noise to add. (2) sample from this posterior (using `sampler`) to generate "fake" observations. (3) condition the model on the new fake observations. @@ -341,8 +343,10 @@ def fantasize( a `model_batch_shape x n' x m`-dim tensor containing the average noise for each batch and output, where `m` is the number of outputs. `noise` must be in the outcome-transformed space if an outcome - transform is used. If None, then the noise will be the inferred - noise level. + transform is used. + If None and using an inferred noise likelihood, the noise will be the + inferred noise level. If using a fixed noise likelihood, the mean across + the observation noise in the training data is used as observation noise. 
kwargs: Will be passed to `model.condition_on_observations` Returns: @@ -352,6 +356,15 @@ def fantasize( raise DeprecationError( "`fantasize` no longer accepts a boolean for `observation_noise`." ) + elif observation_noise is None and isinstance( + self.likelihood, FixedNoiseGaussianLikelihood + ): + if self.num_outputs > 1: + # make noise ... x n x m + observation_noise = self.likelihood.noise.transpose(-1, -2) + else: + observation_noise = self.likelihood.noise.unsqueeze(-1) + observation_noise = observation_noise.mean(dim=-2, keepdim=True) # if the inputs are empty, expand the inputs if X.shape[-2] == 0: output_shape = ( @@ -360,9 +373,10 @@ def fantasize( + self.batch_shape + torch.Size([0, self.num_outputs]) ) + Y = torch.empty(output_shape, dtype=X.dtype, device=X.device) return self.condition_on_observations( X=self.transform_inputs(X), - Y=torch.empty(output_shape, dtype=X.dtype, device=X.device), + Y=Y, **kwargs, ) propagate_grads = kwargs.pop("propagate_grads", False) diff --git a/botorch/models/multitask.py b/botorch/models/multitask.py index 786a495d55..03b0668616 100644 --- a/botorch/models/multitask.py +++ b/botorch/models/multitask.py @@ -327,6 +327,7 @@ class FixedNoiseMultiTaskGP(MultiTaskGP): r"""Multi-Task GP model using an ICM kernel, with known observation noise. DEPRECATED: Please use `MultiTaskGP` with `train_Yvar` instead. + Will be removed in a future release (~v0.10). """ def __init__( diff --git a/botorch/models/utils/parse_training_data.py b/botorch/models/utils/parse_training_data.py index a1f13d60a8..165e7d3d3d 100644 --- a/botorch/models/utils/parse_training_data.py +++ b/botorch/models/utils/parse_training_data.py @@ -14,7 +14,7 @@ import torch from botorch.exceptions import UnsupportedError from botorch.models.model import Model -from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP +from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP from botorch.utils.datasets import RankingDataset, SupervisedDataset from botorch.utils.dispatcher import Dispatcher @@ -88,7 +88,7 @@ def _parse_model_dict( return dispatcher(consumer, next(iter(training_data.values()))) -@dispatcher.register((MultiTaskGP, FixedNoiseMultiTaskGP), dict) +@dispatcher.register(MultiTaskGP, dict) def _parse_multitask_dict( consumer: Model, training_data: Dict[Hashable, SupervisedDataset], diff --git a/botorch/utils/datasets.py b/botorch/utils/datasets.py index 6260d1a3ca..3638ce6810 100644 --- a/botorch/utils/datasets.py +++ b/botorch/utils/datasets.py @@ -154,6 +154,7 @@ class FixedNoiseDataset(SupervisedDataset): observations variances so that `Y[i] ~ N(f(X[i]), Yvar[i])`. NOTE: This is deprecated. Use `SupervisedDataset` instead. + Will be removed in a future release (~v0.11). """ def __init__( diff --git a/docs/models.md b/docs/models.md index f914f9c12d..0b59bcf449 100644 --- a/docs/models.md +++ b/docs/models.md @@ -3,119 +3,123 @@ id: models title: Models --- -Models play an essential role in Bayesian Optimization (BO). A model is used as a -surrogate function for the actual underlying black box function to be optimized. -In BoTorch, a `Model` maps a set of design points to a posterior probability -distribution of its output(s) over the design points. - -In BO, the model used is traditionally a Gaussian Process (GP), -in which case the posterior distribution is a multivariate -normal. While BoTorch supports many GP models, **BoTorch makes no -assumption on the model being a GP** or the posterior being multivariate normal. 
-With the exception of some of the analytic acquisition functions in the +Models play an essential role in Bayesian Optimization (BO). A model is used as +a surrogate function for the actual underlying black box function to be +optimized. In BoTorch, a `Model` maps a set of design points to a posterior +probability distribution of its output(s) over the design points. + +In BO, the model used is traditionally a Gaussian Process (GP), in which case +the posterior distribution is a multivariate normal. While BoTorch supports many +GP models, **BoTorch makes no assumption on the model being a GP** or the +posterior being multivariate normal. With the exception of some of the analytic +acquisition functions in the [`botorch.acquisition.analytic`](../api/acquisition.html#analytic-acquisition-function-api) module, BoTorch’s Monte Carlo-based acquisition functions are compatible with -any model that conforms to the `Model` interface, whether user-implemented or provided. +any model that conforms to the `Model` interface, whether user-implemented or +provided. -Under the hood, BoTorch models are PyTorch `Modules` that implement -the light-weight [`Model`](../api/models.html#model-apis) interface. -When working with GPs, [`GPyTorchModel`](../api/models.html#module-botorch.models.gp_regression) +Under the hood, BoTorch models are PyTorch `Modules` that implement the +light-weight [`Model`](../api/models.html#model-apis) interface. When working +with GPs, +[`GPyTorchModel`](../api/models.html#module-botorch.models.gp_regression) provides a base class for conveniently wrapping GPyTorch models. -Users can extend `Model` and `GPyTorchModel` to generate their own models. -For more on implementing your own models, see +Users can extend `Model` and `GPyTorchModel` to generate their own models. For +more on implementing your own models, see [Implementing Custom Models](#implementing-custom-models) below. - ## Terminology ### Multi-Output and Multi-Task -A `Model` (as in the BoTorch object) may have -multiple outputs, multiple inputs, and may exploit correlation -between different inputs. BoTorch uses the following terminology to -distinguish these model types: - -* *Multi-Output Model*: a `Model` with multiple - outputs. Most BoTorch `Model`s are multi-output. -* *Multi-Task Model*: a `Model` making use of a logical grouping of + +A `Model` (as in the BoTorch object) may have multiple outputs, multiple inputs, +and may exploit correlation between different inputs. BoTorch uses the following +terminology to distinguish these model types: + +- _Multi-Output Model_: a `Model` with multiple outputs. Most BoTorch `Model`s + are multi-output. +- _Multi-Task Model_: a `Model` making use of a logical grouping of inputs/observations (as in the underlying process). For example, there could - be multiple tasks where each task has a different fidelity. - In a multi-task model, the relationship between different - outputs is modeled, with a joint model across tasks. + be multiple tasks where each task has a different fidelity. In a multi-task + model, the relationship between different outputs is modeled, with a joint + model across tasks. Note the following: -* A multi-task (MT) model may or may not be a multi-output model. -For example, if a multi-task model uses different tasks for modeling -but only outputs predictions for one of those tasks, it is single-output. -* Conversely, a multi-output (MO) model may or may not be a multi-task model. 
-For example, multi-output `Model`s that model -different outputs independently rather than -building a joint model are not multi-task. -* If a model is both, we refer to it as a multi-task-multi-output (MTMO) model. + +- A multi-task (MT) model may or may not be a multi-output model. For example, + if a multi-task model uses different tasks for modeling but only outputs + predictions for one of those tasks, it is single-output. +- Conversely, a multi-output (MO) model may or may not be a multi-task model. + For example, multi-output `Model`s that model different outputs independently + rather than building a joint model are not multi-task. +- If a model is both, we refer to it as a multi-task-multi-output (MTMO) model. ### Noise: Homoskedastic, fixed, and heteroskedastic + Noise can be treated in several different ways: -* *Homoskedastic*: Noise is not provided as an input and is inferred, with a -constant variance that does not depend on `X`. Many models, such as -`SingleTaskGP`, take this approach. Use these models if you know that -your observations are noisy, but not how noisy. +- _Homoskedastic_: Noise is not provided as an input and is inferred, with a + constant variance that does not depend on `X`. Many models, such as + `SingleTaskGP`, take this approach. Use these models if you know that your + observations are noisy, but not how noisy. -* *Fixed*: Noise is provided as an input and is not fit. In “fixed noise” models -like `FixedNoiseGP`, noise cannot be predicted out-of-sample because it has -not been modeled. Use these models if you have estimates of the noise in -your observations (e.g. observations may be averages over individual samples -in which case you would provide the mean as observation and the standard -error of the mean as the noise estimate), or if you know your observations are -noiseless (by passing a zero noise level). +- _Fixed_: Noise is provided as an input, `train_Yvar`, and is not fit. In + “fixed noise” models like `SingleTaskGP` with noise observations, noise cannot + be predicted out-of-sample because it has not been modeled. Use these models + if you have estimates of the noise in your observations (e.g. observations may + be averages over individual samples in which case you would provide the mean + as observation and the standard error of the mean as the noise estimate), or + if you know your observations are noiseless (by passing a zero noise level). -* *Heteroskedastic*: Noise is provided as an input and is modeled to allow for -predicting noise out-of-sample. Models like `HeteroskedasticSingleTaskGP` -take this approach. +- _Heteroskedastic_: Noise is provided as an input and is modeled to allow for + predicting noise out-of-sample. Models like `HeteroskedasticSingleTaskGP` take + this approach. ## Standard BoTorch Models BoTorch provides several GPyTorch models to cover most standard BO use cases: ### Single-Task GPs + These models use the same training data for all outputs and assume conditional independence of the outputs given the input. If different training data is -required for each output, use a [`ModelListGP`](../api/models.html#module-botorch.models.model_list_gp_regression) +required for each output, use a +[`ModelListGP`](../api/models.html#module-botorch.models.model_list_gp_regression) instead. -* [`SingleTaskGP`](../api/models.html#botorch.models.gp_regression.SingleTaskGP): a single-task - exact GP that infers a homoskedastic noise level (no noise observations). 
-* [`FixedNoiseGP`](../api/models.html#botorch.models.gp_regression.FixedNoiseGP): a single-task exact GP that - differs from `SingleTaskGP` in using - fixed observation noise levels. It requires noise observations. -* [`HeteroskedasticSingleTaskGP`](../api/models.html#botorch.models.gp_regression.HeteroskedasticSingleTaskGP): - a single-task exact GP that differs from `SingleTaskGP` and `FixedNoiseGP` - in that it models heteroskedastic noise using an additional - internal GP model. It requires noise observations. -* [`MixedSingleTaskGP`](../api/models.html#botorch.models.gp_regression_mixed.MixedSingleTaskGP): a single-task exact - GP that supports mixed search spaces, which combine discrete and continuous features. -* [`SaasFullyBayesianSingleTaskGP`](../api/models.html#botorch.models.fully_bayesian.SaasFullyBayesianSingleTaskGP): - a fully Bayesian single-task GP with the SAAS prior. This model is suitable for - sample-efficient high-dimensional Bayesian optimization. + +- [`SingleTaskGP`](../api/models.html#botorch.models.gp_regression.SingleTaskGP): + a single-task exact GP that supports both inferred and observed noise. When + noise observations are not provided, it infers a homoskedastic noise level. +- [`HeteroskedasticSingleTaskGP`](../api/models.html#botorch.models.gp_regression.HeteroskedasticSingleTaskGP): + a single-task exact GP that differs from `SingleTaskGP` with observed noise in + that it models heteroskedastic noise using an additional internal GP model. It + requires noise observations. +- [`MixedSingleTaskGP`](../api/models.html#botorch.models.gp_regression_mixed.MixedSingleTaskGP): + a single-task exact GP that supports mixed search spaces, which combine + discrete and continuous features. +- [`SaasFullyBayesianSingleTaskGP`](../api/models.html#botorch.models.fully_bayesian.SaasFullyBayesianSingleTaskGP): + a fully Bayesian single-task GP with the SAAS prior. This model is suitable + for sample-efficient high-dimensional Bayesian optimization. ### Model List of Single-Task GPs -* [`ModelListGP`](../api/models.html#module-botorch.models.model_list_gp_regression): A multi-output model in - which outcomes are modeled independently, given a list of any type of - single-task GP. This model should be used when the same training data is not - used for all outputs. + +- [`ModelListGP`](../api/models.html#module-botorch.models.model_list_gp_regression): + A multi-output model in which outcomes are modeled independently, given a list + of any type of single-task GP. This model should be used when the same + training data is not used for all outputs. ### Multi-Task GPs -* [`MultiTaskGP`](../api/models.html#module-botorch.models.multitask): a Hadamard multi-task, - multi-output GP using an ICM kernel, inferring a homoskedastic noise level (does not - require noise observations). -* [`FixedNoiseMultiTaskGP`](../api/models.html#botorch.models.multitask.FixedNoiseMultiTaskGP): - a Hadamard multi-task, multi-output GP using an ICM kernel, with fixed - observation noise levels (requires noise observations). -* [`KroneckerMultiTaskGP`](../api/models.html#botorch.models.multitask.KroneckerMultiTaskGP): A multi-task, - multi-output GP using an ICM kernel, with Kronecker structure. Useful for - multi-fidelity optimization. -* [`SaasFullyBayesianMultiTaskGP`](../api/models.html#saasfullybayesianmultitaskgp): - a fully Bayesian multi-task GP using an ICM kernel. The data kernel uses the SAAS - prior to model high-dimensional parameter spaces. 
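# A short sketch of the consolidated multi-task usage reflected in the updated
# bullets below (data here is illustrative): `MultiTaskGP` now covers both
# inferred noise (no `train_Yvar`) and known observation noise (`train_Yvar`
# provided), replacing the deprecated `FixedNoiseMultiTaskGP`.
import torch
from botorch.models.multitask import MultiTaskGP

X = torch.rand(20, 2, dtype=torch.double)
task = torch.cat([torch.zeros(10, 1), torch.ones(10, 1)]).to(torch.double)
train_X = torch.cat([X, task], dim=-1)  # last column is the task feature
train_Y = torch.sin(X).sum(dim=-1, keepdim=True)
inferred_noise_mtgp = MultiTaskGP(train_X, train_Y, task_feature=-1)
observed_noise_mtgp = MultiTaskGP(
    train_X, train_Y, task_feature=-1, train_Yvar=torch.full_like(train_Y, 0.01)
)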
+ +- [`MultiTaskGP`](../api/models.html#module-botorch.models.multitask): a + Hadamard multi-task, multi-output GP using an ICM kernel. Supports both known + observation noise levels and inferring a homoskedastic noise level (when noise + observations are not provided). +- [`KroneckerMultiTaskGP`](../api/models.html#botorch.models.multitask.KroneckerMultiTaskGP): + A multi-task, multi-output GP using an ICM kernel, with Kronecker structure. + Useful for multi-fidelity optimization. +- [`SaasFullyBayesianMultiTaskGP`](../api/models.html#saasfullybayesianmultitaskgp): + a fully Bayesian multi-task GP using an ICM kernel. The data kernel uses the + SAAS prior to model high-dimensional parameter spaces. All of the above models use Matérn 5/2 kernels with Automatic Relevance Discovery (ARD), and have reasonable priors on hyperparameters that make them @@ -124,59 +128,60 @@ cube** and the **observations are standardized** (zero mean, unit variance). ## Other useful models -* [`ModelList`](../api/models.html#botorch.models.model.ModelList): a multi-output model container - in which outcomes are modeled independently by individual `Model`s (as in `ModelListGP`, but the - component models do not all need to be GPyTorch models). -* [`SingleTaskMultiFidelityGP`](../api/models.html#botorch.models.gp_regression_fidelity.SingleTaskMultiFidelityGP) and - [`FixedNoiseMultiFidelityGP`](../api/models.html#botorch.models.gp_regression_fidelity.FixedNoiseMultiFidelityGP): - Models for multi-fidelity optimization. For more on Multi-Fidelity BO, see the - [tutorial](../tutorials/discrete_multi_fidelity_bo). -* [`HigherOrderGP`](../api/models.html#botorch.models.higher_order_gp.HigherOrderGP): A GP model with - matrix-valued predictions, such as images or grids of images. -* [`PairwiseGP`](../api/models.html#module-botorch.models.pairwise_gp): A probit-likelihood GP that - learns via pairwise comparison data, useful for preference learning. -* [`ApproximateGPyTorchModel`](../api/models.html#botorch.models.approximate_gp.ApproximateGPyTorchModel): for - efficient computation when data is large or responses are non-Gaussian. -* [Deterministic models](../api/models.html#module-botorch.models.deterministic), such as +- [`ModelList`](../api/models.html#botorch.models.model.ModelList): a + multi-output model container in which outcomes are modeled independently by + individual `Model`s (as in `ModelListGP`, but the component models do not all + need to be GPyTorch models). +- [`SingleTaskMultiFidelityGP`](../api/models.html#botorch.models.gp_regression_fidelity.SingleTaskMultiFidelityGP): + A GP model for multi-fidelity optimization. For more on Multi-Fidelity BO, see + the [tutorial](../tutorials/discrete_multi_fidelity_bo). +- [`HigherOrderGP`](../api/models.html#botorch.models.higher_order_gp.HigherOrderGP): + A GP model with matrix-valued predictions, such as images or grids of images. +- [`PairwiseGP`](../api/models.html#module-botorch.models.pairwise_gp): A + probit-likelihood GP that learns via pairwise comparison data, useful for + preference learning. +- [`ApproximateGPyTorchModel`](../api/models.html#botorch.models.approximate_gp.ApproximateGPyTorchModel): + for efficient computation when data is large or responses are non-Gaussian. 
+- [Deterministic models](../api/models.html#module-botorch.models.deterministic), + such as [`AffineDeterministicModel`](../api/models.html#botorch.models.deterministic.AffineDeterministicModel), [`AffineFidelityCostModel`](../api/models.html#botorch.models.cost.AffineFidelityCostModel), [`GenericDeterministicModel`](../api/models.html#botorch.models.deterministic.GenericDeterministicModel), and [`PosteriorMeanModel`](../api/models.html#botorch.models.deterministic.PosteriorMeanModel) - express known input-output relationships; they conform - to the BoTorch `Model` API, so they can easily be used in conjunction with other - BoTorch models. Deterministic models are - useful for multi-objective optimization with known objective - functions and for encoding cost functions for cost-aware acquisition. -* [`SingleTaskVariationalGP`](../api/models.html#botorch.models.approximate_gp.SingleTaskVariationalGP): an - approximate model for faster computation when you have a lot of data or your responses - are non-Gaussian. - + express known input-output relationships; they conform to the BoTorch `Model` + API, so they can easily be used in conjunction with other BoTorch models. + Deterministic models are useful for multi-objective optimization with known + objective functions and for encoding cost functions for cost-aware + acquisition. +- [`SingleTaskVariationalGP`](../api/models.html#botorch.models.approximate_gp.SingleTaskVariationalGP): + an approximate model for faster computation when you have a lot of data or + your responses are non-Gaussian. ## Implementing Custom Models The configurability of the above models is limited (for instance, it is not straightforward to use a different kernel). Doing so is an intentional design -decision -- we believe that having a few simple and easy-to-understand models for -basic use cases is more valuable than having a highly complex and configurable -model class whose implementation is difficult to understand. +decision -- we believe that having a few simple and easy-to-understand models +for basic use cases is more valuable than having a highly complex and +configurable model class whose implementation is difficult to understand. -Instead, we advocate that users implement their own models to cover -more specialized use cases. The light-weight nature of BoTorch's Model API makes -this easy to do. See the +Instead, we advocate that users implement their own models to cover more +specialized use cases. The light-weight nature of BoTorch's Model API makes this +easy to do. See the [Using a custom BoTorch model in Ax](../tutorials/custom_botorch_model_in_ax) tutorial for an example. The BoTorch `Model` interface is light-weight and easy to extend. The only requirement for using BoTorch's Monte-Carlo based acquisition functions is that -the model has a `posterior` method. It takes in a Tensor `X` of design points, and -returns a Posterior object describing the (joint) probability distribution of -the model output(s) over the design points in `X`. The `Posterior` object must -implement an `rsample()` method for sampling from the posterior of the model. -If you wish to use gradient-based optimization algorithms, the model should -allow back-propagating gradients through the samples to the model input. - -If you happen to implement a model that would be useful for other -researchers as well (and involves more than just swapping out the Matérn kernel -for an RBF kernel), please consider [contributing](getting_started#contributing) -this model to BoTorch. 
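# A hedged, minimal sketch of the custom-model idea described in this part of
# the docs: any object exposing a `posterior` method that returns a `Posterior`
# supporting `rsample` can plug into the Monte-Carlo acquisition machinery.
# `ToyQuadraticModel` is hypothetical and purely illustrative.
import torch
from botorch.models.model import Model
from botorch.posteriors.torch import TorchPosterior
from torch.distributions import Normal


class ToyQuadraticModel(Model):
    """A toy model with a known quadratic mean and constant predictive noise."""

    _num_outputs = 1

    @property
    def num_outputs(self) -> int:
        return self._num_outputs

    def posterior(self, X: torch.Tensor, **kwargs) -> TorchPosterior:
        mean = (X**2).sum(dim=-1, keepdim=True)
        return TorchPosterior(Normal(mean, 0.1 * torch.ones_like(mean)))


toy_model = ToyQuadraticModel()
samples = toy_model.posterior(torch.rand(4, 1, 2)).rsample(torch.Size([8]))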
+the model has a `posterior` method. It takes in a Tensor `X` of design points, +and returns a Posterior object describing the (joint) probability distribution +of the model output(s) over the design points in `X`. The `Posterior` object +must implement an `rsample()` method for sampling from the posterior of the +model. If you wish to use gradient-based optimization algorithms, the model +should allow back-propagating gradients through the samples to the model input. + +If you happen to implement a model that would be useful for other researchers as +well (and involves more than just swapping out the Matérn kernel for an RBF +kernel), please consider [contributing](getting_started#contributing) this model +to BoTorch. diff --git a/test/acquisition/test_analytic.py b/test/acquisition/test_analytic.py index 375aca17a0..de3f9af653 100644 --- a/test/acquisition/test_analytic.py +++ b/test/acquisition/test_analytic.py @@ -30,10 +30,11 @@ ScalarizedPosteriorTransform, ) from botorch.exceptions import UnsupportedError -from botorch.models import FixedNoiseGP, SingleTaskGP +from botorch.models import SingleTaskGP from botorch.posteriors import GPyTorchPosterior from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood NEI_NOISE = [ @@ -782,7 +783,7 @@ def _get_model(self, dtype=torch.float): noise = torch.tensor(NEI_NOISE, device=self.device, dtype=dtype) train_y += noise train_yvar = torch.full_like(train_y, 0.25**2) - model = FixedNoiseGP(train_X=train_x, train_Y=train_y, train_Yvar=train_yvar) + model = SingleTaskGP(train_X=train_x, train_Y=train_y, train_Yvar=train_yvar) model.load_state_dict(state_dict) model.to(train_x) model.eval() @@ -798,7 +799,8 @@ def test_noisy_expected_improvement(self): # before assigning, check that the attributes exist self.assertTrue(hasattr(LogNEI, "model")) self.assertTrue(hasattr(LogNEI, "best_f")) - self.assertTrue(isinstance(LogNEI.model, FixedNoiseGP)) + self.assertIsInstance(LogNEI.model, SingleTaskGP) + self.assertIsInstance(LogNEI.model.likelihood, FixedNoiseGaussianLikelihood) LogNEI.model = nEI.model # let the two share their values and fantasies LogNEI.best_f = nEI.best_f @@ -837,11 +839,11 @@ def test_noisy_expected_improvement(self): # regime where the naive implementation looses accuracy. 
atol = 2e-5 if dtype == torch.float32 else 1e-12 rtol = atol - self.assertTrue( - torch.allclose(X_test.grad[0], X_test_log.grad[0], atol=atol, rtol=rtol) + self.assertAllClose( + X_test.grad[0], X_test_log.grad[0], atol=atol, rtol=rtol ) - # test non-FixedNoiseGP model + # test inferred noise model other_model = SingleTaskGP(X_observed, model.train_targets.unsqueeze(-1)) for constructor in ( NoisyExpectedImprovement, diff --git a/test/acquisition/test_input_constructors.py b/test/acquisition/test_input_constructors.py index 53b36c9aad..30ce1e854b 100644 --- a/test/acquisition/test_input_constructors.py +++ b/test/acquisition/test_input_constructors.py @@ -80,7 +80,7 @@ project_to_target_fidelity, ) from botorch.exceptions.errors import UnsupportedError -from botorch.models import FixedNoiseGP, MultiTaskGP, SingleTaskGP +from botorch.models import MultiTaskGP, SingleTaskGP from botorch.models.deterministic import FixedSingleSampleModel from botorch.models.model_list_gp_regression import ModelListGP from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler @@ -357,7 +357,7 @@ def test_construct_inputs_noisy_ei(self) -> None: for acqf_cls in [NoisyExpectedImprovement, LogNoisyExpectedImprovement]: with self.subTest(acqf_cls=acqf_cls): c = get_acqf_input_constructor(acqf_cls) - mock_model = FixedNoiseGP( + mock_model = SingleTaskGP( train_X=torch.rand((2, 2)), train_Y=torch.rand((2, 1)), train_Yvar=torch.rand((2, 1)), @@ -1269,7 +1269,7 @@ def test_constructors_like_ExpectedImprovement(self) -> None: qLogNoisyExpectedImprovement, qProbabilityOfImprovement, ] - model = FixedNoiseGP( + model = SingleTaskGP( train_X=torch.rand((4, 2)), train_Y=torch.rand((4, 1)), train_Yvar=torch.ones((4, 1)), diff --git a/test/models/test_contextual.py b/test/models/test_contextual.py index 758a084907..54b41d50aa 100644 --- a/test/models/test_contextual.py +++ b/test/models/test_contextual.py @@ -10,48 +10,59 @@ import torch from botorch.fit import fit_gpytorch_mll from botorch.models.contextual import LCEAGP, SACGP -from botorch.models.gp_regression import FixedNoiseGP +from botorch.models.gp_regression import SingleTaskGP from botorch.models.kernels.contextual_lcea import LCEAKernel from botorch.models.kernels.contextual_sac import SACKernel from botorch.utils.datasets import SupervisedDataset from botorch.utils.testing import BotorchTestCase from gpytorch.distributions.multivariate_normal import MultivariateNormal +from gpytorch.likelihoods.gaussian_likelihood import ( + FixedNoiseGaussianLikelihood, + GaussianLikelihood, +) from gpytorch.means import ConstantMean from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from torch import Tensor def _gen_datasets( + infer_noise: bool = False, **tkwargs, ) -> Tuple[Dict[int, SupervisedDataset], Tuple[Tensor, Tensor, Tensor]]: train_X = torch.tensor( [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]], **tkwargs ) train_Y = torch.tensor([[1.0], [2.0], [3.0]], **tkwargs) - train_Yvar = 0.01 * torch.ones(3, 1, **tkwargs) + train_Yvar = None if infer_noise else torch.full_like(train_Y, 0.01) datasets = SupervisedDataset( X=train_X, Y=train_Y, + Yvar=train_Yvar, feature_names=[f"x{i}" for i in range(train_X.shape[-1])], outcome_names=["y"], - Yvar=train_Yvar, ) return datasets, (train_X, train_Y, train_Yvar) class TestContextualGP(BotorchTestCase): def test_SACGP(self): - for dtype in (torch.float, torch.double): + for dtype, infer_noise in ((torch.float, False), (torch.double, True)): tkwargs = {"device": 
self.device, "dtype": dtype} - datasets, (train_X, train_Y, train_Yvar) = _gen_datasets(**tkwargs) + datasets, (train_X, train_Y, train_Yvar) = _gen_datasets( + infer_noise, **tkwargs + ) self.decomposition = {"1": [0, 3], "2": [1, 2]} model = SACGP(train_X, train_Y, train_Yvar, self.decomposition) mll = ExactMarginalLogLikelihood(model.likelihood, model) fit_gpytorch_mll(mll, optimizer_kwargs={"options": {"maxiter": 1}}) - self.assertIsInstance(model, FixedNoiseGP) + self.assertIsInstance(model, SingleTaskGP) + self.assertIsInstance( + model.likelihood, + GaussianLikelihood if infer_noise else FixedNoiseGaussianLikelihood, + ) self.assertDictEqual(model.decomposition, self.decomposition) self.assertIsInstance(model.mean_module, ConstantMean) self.assertIsInstance(model.covar_module, SACKernel) @@ -90,13 +101,14 @@ def test_SACGP_construct_inputs(self): self.assertTrue(train_Yvar.equal(data_dict["train_Yvar"])) self.assertDictEqual(data_dict["decomposition"], self.decomposition) - def testLCEAGP(self): - for dtype in (torch.float, torch.double): + def test_LCEAGP(self): + for dtype, infer_noise in ((torch.float, False), (torch.double, True)): tkwargs = {"device": self.device, "dtype": dtype} - datasets, (train_X, train_Y, train_Yvar) = _gen_datasets(**tkwargs) + datasets, (train_X, train_Y, train_Yvar) = _gen_datasets( + infer_noise, **tkwargs + ) # Test setting attributes decomposition = {"1": [0, 1], "2": [2, 3]} - # test instantiate model model = LCEAGP( train_X=train_X, @@ -110,6 +122,10 @@ def testLCEAGP(self): self.assertIsInstance(model, LCEAGP) self.assertIsInstance(model.covar_module, LCEAKernel) self.assertDictEqual(model.decomposition, decomposition) + self.assertIsInstance( + model.likelihood, + GaussianLikelihood if infer_noise else FixedNoiseGaussianLikelihood, + ) test_x = torch.rand(5, 4, device=self.device, dtype=dtype) posterior = model(test_x) diff --git a/test/models/test_converter.py b/test/models/test_converter.py index 7ae34bcd50..bb83173ae7 100644 --- a/test/models/test_converter.py +++ b/test/models/test_converter.py @@ -8,7 +8,6 @@ import torch from botorch.exceptions import UnsupportedError from botorch.models import ( - FixedNoiseGP, HeteroskedasticSingleTaskGP, ModelListGP, SingleTaskGP, @@ -24,6 +23,7 @@ from botorch.utils.testing import BotorchTestCase from gpytorch.kernels import RBFKernel from gpytorch.likelihoods import GaussianLikelihood +from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from .test_gpytorch import SimpleGPyTorchModel @@ -39,10 +39,14 @@ def test_batched_to_model_list(self): batch_gp = SingleTaskGP(train_X, train_Y) list_gp = batched_to_model_list(batch_gp) self.assertIsInstance(list_gp, ModelListGP) - # test FixedNoiseGP - batch_gp = FixedNoiseGP(train_X, train_Y, torch.rand_like(train_Y)) + self.assertIsInstance(list_gp.models[0].likelihood, GaussianLikelihood) + # test observed noise + batch_gp = SingleTaskGP(train_X, train_Y, torch.rand_like(train_Y)) list_gp = batched_to_model_list(batch_gp) self.assertIsInstance(list_gp, ModelListGP) + self.assertIsInstance( + list_gp.models[0].likelihood, FixedNoiseGaussianLikelihood + ) # test SingleTaskMultiFidelityGP for lin_trunc in (False, True): batch_gp = SingleTaskMultiFidelityGP( @@ -103,11 +107,12 @@ def test_model_list_to_batched(self): list_gp = ModelListGP(gp1, gp2) batch_gp = model_list_to_batched(list_gp) self.assertIsInstance(batch_gp, SingleTaskGP) + self.assertIsInstance(batch_gp.likelihood, GaussianLikelihood) # test degenerate (single model) 
batch_gp = model_list_to_batched(ModelListGP(gp1)) self.assertEqual(batch_gp._num_outputs, 1) - # test different model classes - gp2 = FixedNoiseGP(train_X, train_Y1, torch.ones_like(train_Y1)) + # test mixing different likelihoods + gp2 = SingleTaskGP(train_X, train_Y1, torch.ones_like(train_Y1)) with self.assertRaises(UnsupportedError): model_list_to_batched(ModelListGP(gp1, gp2)) # test non-batched models @@ -158,14 +163,15 @@ def test_model_list_to_batched(self): list_gp = ModelListGP(gp1, gp2) with self.assertRaises(UnsupportedError): model_list_to_batched(list_gp) - # test FixedNoiseGP + # test observed noise train_X = torch.rand(10, 2, device=self.device, dtype=dtype) train_Y1 = train_X.sum(dim=-1, keepdim=True) train_Y2 = (train_X[:, 0] - train_X[:, 1]).unsqueeze(-1) - gp1_ = FixedNoiseGP(train_X, train_Y1, torch.rand_like(train_Y1)) - gp2_ = FixedNoiseGP(train_X, train_Y2, torch.rand_like(train_Y2)) + gp1_ = SingleTaskGP(train_X, train_Y1, torch.rand_like(train_Y1)) + gp2_ = SingleTaskGP(train_X, train_Y2, torch.rand_like(train_Y2)) list_gp = ModelListGP(gp1_, gp2_) batch_gp = model_list_to_batched(list_gp) + self.assertIsInstance(batch_gp.likelihood, FixedNoiseGaussianLikelihood) # test SingleTaskMultiFidelityGP gp1_ = SingleTaskMultiFidelityGP(train_X, train_Y1, iteration_fidelity=1) gp2_ = SingleTaskMultiFidelityGP(train_X, train_Y2, iteration_fidelity=1) @@ -249,8 +255,8 @@ def test_roundtrip(self): sd_recov = batch_gp_recov.state_dict() self.assertTrue(set(sd_orig) == set(sd_recov)) self.assertTrue(all(torch.equal(sd_orig[k], sd_recov[k]) for k in sd_orig)) - # FixedNoiseGP - batch_gp = FixedNoiseGP(train_X, train_Y, torch.rand_like(train_Y)) + # Observed noise + batch_gp = SingleTaskGP(train_X, train_Y, torch.rand_like(train_Y)) list_gp = batched_to_model_list(batch_gp) batch_gp_recov = model_list_to_batched(list_gp) sd_orig = batch_gp.state_dict() @@ -299,11 +305,14 @@ def test_batched_multi_output_to_single_output(self): gp2 = SingleTaskGP(train_X, train_Y, likelihood=GaussianLikelihood()) with self.assertRaises(NotImplementedError): batched_multi_output_to_single_output(gp2) - # test FixedNoiseGP + # test observed noise train_X = torch.rand(10, 2, device=self.device, dtype=dtype) - batched_mo_model = FixedNoiseGP(train_X, train_Y, torch.rand_like(train_Y)) + batched_mo_model = SingleTaskGP(train_X, train_Y, torch.rand_like(train_Y)) batched_so_model = batched_multi_output_to_single_output(batched_mo_model) - self.assertIsInstance(batched_so_model, FixedNoiseGP) + self.assertIsInstance(batched_so_model, SingleTaskGP) + self.assertIsInstance( + batched_so_model.likelihood, FixedNoiseGaussianLikelihood + ) self.assertEqual(batched_so_model.num_outputs, 1) # test SingleTaskMultiFidelityGP batched_mo_model = SingleTaskMultiFidelityGP( diff --git a/test/models/test_gp_regression.py b/test/models/test_gp_regression.py index 289b08bc64..40d6c94b7a 100644 --- a/test/models/test_gp_regression.py +++ b/test/models/test_gp_regression.py @@ -220,14 +220,14 @@ def test_condition_on_observations(self): ) c_kwargs = ( {"noise": torch.full_like(Y_fant, 0.01)} - if isinstance(model, FixedNoiseGP) + if isinstance(model.likelihood, FixedNoiseGaussianLikelihood) else {} ) cm = model.condition_on_observations(X_fant, Y_fant, **c_kwargs) # fantasize at same input points (check proper broadcasting) c_kwargs_same_inputs = ( {"noise": torch.full_like(Y_fant[0], 0.01)} - if isinstance(model, FixedNoiseGP) + if isinstance(model.likelihood, FixedNoiseGaussianLikelihood) else {} ) cm_same_inputs = 
model.condition_on_observations( @@ -277,7 +277,7 @@ def test_condition_on_observations(self): model_non_batch.posterior(torch.rand(torch.Size([4, 1]), **tkwargs)) c_kwargs = ( {"noise": torch.full_like(Y_fant[0, 0, :], 0.01)} - if isinstance(model, FixedNoiseGP) + if isinstance(model.likelihood, FixedNoiseGaussianLikelihood) else {} ) cm_non_batch = model_non_batch.condition_on_observations( @@ -399,6 +399,8 @@ def test_set_transformed_inputs(self): class TestFixedNoiseGP(TestSingleTaskGP): + model_class = FixedNoiseGP + def _get_model_and_data( self, batch_shape, @@ -417,7 +419,14 @@ def _get_model_and_data( "input_transform": input_transform, "outcome_transform": outcome_transform, } - model = FixedNoiseGP(**model_kwargs, **extra_model_kwargs) + if self.model_class is FixedNoiseGP: + with self.assertWarnsRegex( + DeprecationWarning, + "`FixedNoiseGP` has been merged into `SingleTaskGP`. ", + ): + model = FixedNoiseGP(**model_kwargs, **extra_model_kwargs) + else: + model = self.model_class(**model_kwargs, **extra_model_kwargs) return model, model_kwargs def _get_extra_model_kwargs(self): @@ -528,6 +537,11 @@ def test_fantasized_noise(self): ) +class TestFixedNoiseSingleTaskGP(TestFixedNoiseGP): + # Repeat the FixedNoiseGP tests using SingleTaskGP. + model_class = SingleTaskGP + + class TestHeteroskedasticSingleTaskGP(TestSingleTaskGP): def _get_model_and_data( self, batch_shape, m, outcome_transform=None, input_transform=None, **tkwargs diff --git a/test/models/test_gp_regression_fidelity.py b/test/models/test_gp_regression_fidelity.py index 66b41d0d9c..42c511f6f4 100644 --- a/test/models/test_gp_regression_fidelity.py +++ b/test/models/test_gp_regression_fidelity.py @@ -12,7 +12,6 @@ from botorch.exceptions.errors import UnsupportedError from botorch.exceptions.warnings import OptimizationWarning from botorch.fit import fit_gpytorch_mll -from botorch.models.gp_regression import FixedNoiseGP from botorch.models.gp_regression_fidelity import ( FixedNoiseMultiFidelityGP, SingleTaskMultiFidelityGP, @@ -243,14 +242,14 @@ def test_condition_on_observations(self): ) c_kwargs = ( {"noise": torch.full_like(Y_fant, 0.01)} - if isinstance(model, FixedNoiseGP) + if isinstance(model.likelihood, FixedNoiseGaussianLikelihood) else {} ) cm = model.condition_on_observations(X_fant, Y_fant, **c_kwargs) # fantasize at different same input points c_kwargs_same_inputs = ( {"noise": torch.full_like(Y_fant[0], 0.01)} - if isinstance(model, FixedNoiseGP) + if isinstance(model.likelihood, FixedNoiseGaussianLikelihood) else {} ) cm_same_inputs = model.condition_on_observations( @@ -309,7 +308,9 @@ def test_condition_on_observations(self): ) c_kwargs = ( {"noise": torch.full_like(Y_fant[0, 0, :], 0.01)} - if isinstance(model, FixedNoiseGP) + if isinstance( + model.likelihood, FixedNoiseGaussianLikelihood + ) else {} ) mnb = model_non_batch @@ -435,6 +436,8 @@ def test_construct_inputs(self): class TestFixedNoiseMultiFidelityGP(TestSingleTaskMultiFidelityGP): + model_class = FixedNoiseMultiFidelityGP + def _get_model_and_data( self, iteration_fidelity, @@ -468,7 +471,11 @@ def _get_model_and_data( model_kwargs["outcome_transform"] = outcome_transform if input_transform is not None: model_kwargs["input_transform"] = input_transform - model = FixedNoiseMultiFidelityGP(**model_kwargs) + if self.model_class is FixedNoiseMultiFidelityGP: + with self.assertWarnsRegex(DeprecationWarning, "SingleTaskMultiFidelityGP"): + model = FixedNoiseMultiFidelityGP(**model_kwargs) + else: + model = self.model_class(**model_kwargs) 
return model, model_kwargs def test_init_error(self): @@ -558,3 +565,8 @@ def test_construct_inputs(self): self.assertEqual(data_dict.get("data_fidelities", None), [1]) self.assertTrue(kwargs["train_X"].equal(data_dict["train_X"])) self.assertTrue(kwargs["train_Y"].equal(data_dict["train_Y"])) + + +class TestFixedNoiseSingleTaskMultiFidelityGP(TestFixedNoiseMultiFidelityGP): + # Test SingleTaskMultiFidelityGP with observed noise. + model_class = SingleTaskMultiFidelityGP diff --git a/test/models/test_gp_regression_mixed.py b/test/models/test_gp_regression_mixed.py index 2de7abcdf0..6406189080 100644 --- a/test/models/test_gp_regression_mixed.py +++ b/test/models/test_gp_regression_mixed.py @@ -8,7 +8,7 @@ import warnings import torch -from botorch.exceptions.warnings import InputDataWarning, OptimizationWarning +from botorch.exceptions.warnings import OptimizationWarning from botorch.fit import fit_gpytorch_mll from botorch.models.converter import batched_to_model_list from botorch.models.gp_regression_mixed import MixedSingleTaskGP @@ -21,6 +21,8 @@ from gpytorch.kernels.kernel import AdditiveKernel, ProductKernel from gpytorch.kernels.matern_kernel import MaternKernel from gpytorch.kernels.scale_kernel import ScaleKernel +from gpytorch.likelihoods import FixedNoiseGaussianLikelihood +from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood from gpytorch.means import ConstantMean from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood @@ -28,14 +30,15 @@ class TestMixedSingleTaskGP(BotorchTestCase): + observed_noise = False + def test_gp(self): d = 3 bounds = torch.tensor([[-1.0] * d, [1.0] * d]) - for batch_shape, m, ncat, dtype in itertools.product( - (torch.Size(), torch.Size([2])), - (1, 2), - (0, 1, 3), - (torch.float, torch.double), + for batch_shape, m, ncat, dtype, observed_noise in ( + (torch.Size(), 1, 0, torch.float, False), + (torch.Size(), 2, 1, torch.double, True), + (torch.Size([2]), 2, 3, torch.double, False), ): tkwargs = {"device": self.device, "dtype": dtype} train_X, train_Y = _get_random_data( @@ -62,7 +65,13 @@ def test_gp(self): MixedSingleTaskGP(train_X, train_Y, cat_dims=cat_dims) continue - model = MixedSingleTaskGP(train_X, train_Y, cat_dims=cat_dims) + train_Yvar = torch.full_like(train_Y, 0.1) if observed_noise else None + model = MixedSingleTaskGP( + train_X=train_X, + train_Y=train_Y, + cat_dims=cat_dims, + train_Yvar=train_Yvar, + ) self.assertEqual(model._ignore_X_dims_scaling_check, cat_dims) mll = ExactMarginalLogLikelihood(model.likelihood, model).to(**tkwargs) with warnings.catch_warnings(): @@ -90,6 +99,10 @@ def test_gp(self): else: self.assertIsInstance(model.covar_module, ScaleKernel) self.assertIsInstance(model.covar_module.base_kernel, CategoricalKernel) + if observed_noise: + self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood) + else: + self.assertIsInstance(model.likelihood, GaussianLikelihood) # test posterior # test non batch evaluation @@ -127,20 +140,24 @@ def test_gp(self): with self.assertRaisesRegex(NotImplementedError, "not supported"): batched_to_model_list(model) - def test_condition_on_observations(self): + def test_condition_on_observations__(self): d = 3 - for batch_shape, m, ncat, dtype in itertools.product( - (torch.Size(), torch.Size([2])), - (1, 2), - (1, 2), - (torch.float, torch.double), + for batch_shape, m, ncat, dtype, observed_noise in ( + (torch.Size(), 2, 1, torch.float, True), + (torch.Size([2]), 1, 2, torch.double, False), ): tkwargs = {"device": self.device, 
"dtype": dtype} train_X, train_Y = _get_random_data( batch_shape=batch_shape, m=m, d=d, **tkwargs ) cat_dims = list(range(ncat)) - model = MixedSingleTaskGP(train_X, train_Y, cat_dims=cat_dims) + train_Yvar = torch.full_like(train_Y, 0.1) if observed_noise else None + model = MixedSingleTaskGP( + train_X=train_X, + train_Y=train_Y, + cat_dims=cat_dims, + train_Yvar=train_Yvar, + ) # evaluate model model.posterior(torch.rand(torch.Size([4, d]), **tkwargs)) @@ -151,11 +168,16 @@ def test_condition_on_observations(self): X_fant, Y_fant = _get_random_data( fant_shape + batch_shape, m=m, d=d, n=3, **tkwargs ) - cm = model.condition_on_observations(X_fant, Y_fant) + additional_kwargs = ( + {"noise": torch.full_like(Y_fant, 0.1)} if observed_noise else {} + ) + cm = model.condition_on_observations(X_fant, Y_fant, **additional_kwargs) # fantasize at same input points (check proper broadcasting) + additional_kwargs = ( + {"noise": torch.full_like(Y_fant[0], 0.1)} if observed_noise else {} + ) cm_same_inputs = model.condition_on_observations( - X_fant[0], - Y_fant, + X_fant[0], Y_fant, **additional_kwargs ) test_Xs = [ @@ -189,14 +211,20 @@ def test_condition_on_observations(self): "train_Y": train_Y[0], "cat_dims": cat_dims, } + if observed_noise: + model_kwargs_non_batch["train_Yvar"] = train_Yvar[0] model_non_batch = type(model)(**model_kwargs_non_batch) model_non_batch.load_state_dict(state_dict_non_batch) model_non_batch.eval() model_non_batch.likelihood.eval() model_non_batch.posterior(torch.rand(torch.Size([4, d]), **tkwargs)) + additional_kwargs = ( + {"noise": torch.full_like(Y_fant, 0.1)} + if observed_noise + else {} + ) cm_non_batch = model_non_batch.condition_on_observations( - X_fant[0][0], - Y_fant[:, 0, :], + X_fant[0][0], Y_fant[:, 0, :], **additional_kwargs ) non_batch_posterior = cm_non_batch.posterior(test_X) self.assertTrue( @@ -218,18 +246,22 @@ def test_condition_on_observations(self): def test_fantasize(self): d = 3 - for batch_shape, m, ncat, dtype in itertools.product( - (torch.Size(), torch.Size([2])), - (1, 2), - (1, 2), - (torch.float, torch.double), + for batch_shape, m, ncat, dtype, observed_noise in ( + (torch.Size(), 2, 1, torch.float, True), + (torch.Size([2]), 1, 2, torch.double, False), ): tkwargs = {"device": self.device, "dtype": dtype} train_X, train_Y = _get_random_data( batch_shape=batch_shape, m=m, d=d, **tkwargs ) + train_Yvar = torch.full_like(train_Y, 0.1) if observed_noise else None cat_dims = list(range(ncat)) - model = MixedSingleTaskGP(train_X, train_Y, cat_dims=cat_dims) + model = MixedSingleTaskGP( + train_X=train_X, + train_Y=train_Y, + cat_dims=cat_dims, + train_Yvar=train_Yvar, + ) # fantasize X_f = torch.rand(torch.Size(batch_shape + torch.Size([4, d])), **tkwargs) @@ -295,10 +327,9 @@ def test_construct_inputs(self): feature_names=[f"x{i}" for i in range(d)], outcome_names=["y"], ) - with self.assertWarnsRegex(InputDataWarning, "train_Yvar"): - model_kwargs = MixedSingleTaskGP.construct_inputs( - training_data, categorical_features=cat_dims - ) + model_kwargs = MixedSingleTaskGP.construct_inputs( + training_data, categorical_features=cat_dims + ) self.assertTrue(X.equal(model_kwargs["train_X"])) self.assertTrue(Y.equal(model_kwargs["train_Y"])) - self.assertNotIn("train_Yvar", model_kwargs) + self.assertTrue(Y.equal(model_kwargs["train_Yvar"])) diff --git a/test/models/test_model_list_gp_regression.py b/test/models/test_model_list_gp_regression.py index fb9c80c535..d5a1d84df3 100644 --- a/test/models/test_model_list_gp_regression.py +++ 
b/test/models/test_model_list_gp_regression.py @@ -13,8 +13,8 @@ from botorch.exceptions.errors import BotorchTensorDimensionError from botorch.exceptions.warnings import OptimizationWarning from botorch.fit import fit_gpytorch_mll -from botorch.models import ModelListGP -from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP +from botorch.models.gp_regression import SingleTaskGP +from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import MultiTaskGP from botorch.models.transforms.input import Normalize from botorch.models.transforms.outcome import ChainedOutcomeTransform, Log, Standardize @@ -26,6 +26,10 @@ from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal from gpytorch.kernels import MaternKernel, ScaleKernel from gpytorch.likelihoods import LikelihoodList +from gpytorch.likelihoods.gaussian_likelihood import ( + FixedNoiseGaussianLikelihood, + GaussianLikelihood, +) from gpytorch.means import ConstantMean from gpytorch.mlls import SumMarginalLogLikelihood from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood @@ -65,33 +69,23 @@ def _get_model( if fixed_noise: train_y1_var = 0.1 + 0.1 * torch.rand_like(train_y1, **tkwargs) train_y2_var = 0.1 + 0.1 * torch.rand_like(train_y2, **tkwargs) - model1 = FixedNoiseGP( - train_X=train_x1, - train_Y=train_y1, - train_Yvar=train_y1_var, - outcome_transform=octfs[0], - input_transform=intfs[0], - ) - model2 = FixedNoiseGP( - train_X=train_x2, - train_Y=train_y2, - train_Yvar=train_y2_var, - outcome_transform=octfs[1], - input_transform=intfs[1], - ) else: - model1 = SingleTaskGP( - train_X=train_x1, - train_Y=train_y1, - outcome_transform=octfs[0], - input_transform=intfs[0], - ) - model2 = SingleTaskGP( - train_X=train_x2, - train_Y=train_y2, - outcome_transform=octfs[1], - input_transform=intfs[1], - ) + train_y1_var = None + train_y2_var = None + model1 = SingleTaskGP( + train_X=train_x1, + train_Y=train_y1, + train_Yvar=train_y1_var, + outcome_transform=octfs[0], + input_transform=intfs[0], + ) + model2 = SingleTaskGP( + train_X=train_x2, + train_Y=train_y2, + train_Yvar=train_y2_var, + outcome_transform=octfs[1], + input_transform=intfs[1], + ) model = ModelListGP(model1, model2) return model.to(**tkwargs) @@ -379,18 +373,15 @@ def test_transform_revert_train_inputs(self): self.assertTrue(torch.equal(m._original_train_inputs, org_inputs[i])) def test_fantasize(self): - for model_cls in (SingleTaskGP, FixedNoiseGP): + for fixed_noise in (False, True): x1 = torch.rand(5, 2) y1 = torch.rand(5, 1) x2 = torch.rand(5, 2) y2 = torch.rand(5, 1) - m1_kwargs = {} - m2_kwargs = {} - if model_cls is FixedNoiseGP: - m1_kwargs = {"train_Yvar": torch.full_like(y1, 0.1)} - m2_kwargs = {"train_Yvar": torch.full_like(y2, 0.2)} - m1 = model_cls(x1, y1, **m1_kwargs).eval() - m2 = model_cls(x2, y2, **m2_kwargs).eval() + yvar1 = torch.full_like(y1, 0.1) if fixed_noise else None + yvar2 = torch.full_like(y2, 0.2) if fixed_noise else None + m1 = SingleTaskGP(x1, y1, yvar1).eval() + m2 = SingleTaskGP(x2, y2, yvar2).eval() modellist = ModelListGP(m1, m2) fm = modellist.fantasize( torch.rand(3, 2), sampler=IIDNormalSampler(sample_shape=torch.Size([2])) @@ -398,7 +389,11 @@ def test_fantasize(self): self.assertIsInstance(fm, ModelListGP) for i in range(2): fm_i = fm.models[i] - self.assertIsInstance(fm_i, model_cls) + self.assertIsInstance(fm_i, SingleTaskGP) + self.assertIsInstance( + fm_i.likelihood, + FixedNoiseGaussianLikelihood if fixed_noise else 
GaussianLikelihood, + ) self.assertEqual(fm_i.train_inputs[0].shape, torch.Size([2, 8, 2])) self.assertEqual(fm_i.train_targets.shape, torch.Size([2, 8])) @@ -418,14 +413,18 @@ def test_fantasize(self): self.assertIsInstance(fm, ModelListGP) for i in range(2): fm_i = fm.models[i] - self.assertIsInstance(fm_i, model_cls) + self.assertIsInstance(fm_i, SingleTaskGP) + self.assertIsInstance( + fm_i.likelihood, + FixedNoiseGaussianLikelihood if fixed_noise else GaussianLikelihood, + ) num_points = 7 - i self.assertEqual( fm_i.train_inputs[0].shape, torch.Size([2, num_points, 2]) ) self.assertEqual(fm_i.train_targets.shape, torch.Size([2, num_points])) # test decoupled with observation_noise - if model_cls is FixedNoiseGP: + if fixed_noise: # already transformed observation_noise = torch.full( (3, 2), 0.3, dtype=x1.dtype, device=x1.device @@ -440,7 +439,8 @@ def test_fantasize(self): self.assertIsInstance(fm, ModelListGP) for i in range(2): fm_i = fm.models[i] - self.assertIsInstance(fm_i, model_cls) + self.assertIsInstance(fm_i, SingleTaskGP) + self.assertIsInstance(fm_i.likelihood, FixedNoiseGaussianLikelihood) num_points = 7 - i self.assertEqual( fm_i.train_inputs[0].shape, torch.Size([2, num_points, 2]) @@ -598,8 +598,8 @@ def test_fantasize_with_outcome_transform_fixed_noise(self) -> None: yvar = torch.full_like(Y, 1e-4) yvar2 = 2 * yvar model = ModelListGP( - FixedNoiseGP(X, Y, yvar, outcome_transform=Standardize(m=1)), - FixedNoiseGP(X, Y2, yvar2, outcome_transform=Standardize(m=1)), + SingleTaskGP(X, Y, yvar, outcome_transform=Standardize(m=1)), + SingleTaskGP(X, Y2, yvar2, outcome_transform=Standardize(m=1)), ) # test exceptions eval_mask = torch.zeros( @@ -611,7 +611,6 @@ def test_fantasize_with_outcome_transform_fixed_noise(self) -> None: f"{' x '.join(str(i) for i in eval_mask.shape)}`." 
) with self.assertRaisesRegex(BotorchTensorDimensionError, msg): - model.fantasize( X, evaluation_mask=eval_mask, diff --git a/test/models/utils/test_parse_training_data.py b/test/models/utils/test_parse_training_data.py index 2e150742ad..c8f986b03b 100644 --- a/test/models/utils/test_parse_training_data.py +++ b/test/models/utils/test_parse_training_data.py @@ -6,13 +6,12 @@ import torch from botorch.exceptions import UnsupportedError -from botorch.models.gp_regression import FixedNoiseGP from botorch.models.model import Model from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP from botorch.models.utils.parse_training_data import parse_training_data from botorch.utils.containers import SliceContainer -from botorch.utils.datasets import FixedNoiseDataset, RankingDataset, SupervisedDataset +from botorch.utils.datasets import RankingDataset, SupervisedDataset from botorch.utils.testing import BotorchTestCase from torch import cat, long, rand, Size, tensor @@ -32,26 +31,17 @@ def test_supervised(self): self.assertIsInstance(parse, dict) self.assertTrue(torch.equal(dataset.X, parse["train_X"])) self.assertTrue(torch.equal(dataset.Y, parse["train_Y"])) - - def test_fixedNoise(self): - # Test passing a `SupervisedDataset` - dataset = SupervisedDataset( - X=rand(3, 2), Y=rand(3, 1), feature_names=["a", "b"], outcome_names=["y"] - ) - parse = parse_training_data(FixedNoiseGP, dataset) self.assertTrue("train_Yvar" not in parse) - self.assertTrue(torch.equal(dataset.X, parse["train_X"])) - self.assertTrue(torch.equal(dataset.Y, parse["train_Y"])) - # Test passing a `FixedNoiseDataset` - dataset = FixedNoiseDataset( + # Test with noise + dataset = SupervisedDataset( X=rand(3, 2), Y=rand(3, 1), Yvar=rand(3, 1), feature_names=["a", "b"], outcome_names=["y"], ) - parse = parse_training_data(FixedNoiseGP, dataset) + parse = parse_training_data(Model, dataset) self.assertTrue(torch.equal(dataset.X, parse["train_X"])) self.assertTrue(torch.equal(dataset.Y, parse["train_Y"])) self.assertTrue(torch.equal(dataset.Yvar, parse["train_Yvar"])) diff --git a/test/sampling/pathwise/test_posterior_samplers.py b/test/sampling/pathwise/test_posterior_samplers.py index d6b3ca6fc2..8828e34f63 100644 --- a/test/sampling/pathwise/test_posterior_samplers.py +++ b/test/sampling/pathwise/test_posterior_samplers.py @@ -6,17 +6,10 @@ from __future__ import annotations -from collections import defaultdict from copy import deepcopy -from itertools import product import torch -from botorch.models import ( - FixedNoiseGP, - ModelListGP, - SingleTaskGP, - SingleTaskVariationalGP, -) +from botorch.models import ModelListGP, SingleTaskGP, SingleTaskVariationalGP from botorch.models.transforms.input import Normalize from botorch.models.transforms.outcome import Standardize from botorch.sampling.pathwise import draw_matheron_paths, MatheronPath, PathList @@ -32,82 +25,61 @@ class TestPosteriorSamplers(BotorchTestCase): def setUp(self) -> None: super().setUp() - self.models = defaultdict(list) + tkwargs = {"device": self.device, "dtype": torch.float64} + torch.manual_seed(0) + + base = MaternKernel(nu=2.5, ard_num_dims=2, batch_shape=Size([])) + base.lengthscale = 0.1 + 0.3 * torch.rand_like(base.lengthscale) + kernel = ScaleKernel(base) + kernel.to(**tkwargs) + + uppers = 1 + 9 * torch.rand(base.lengthscale.shape[-1], **tkwargs) + bounds = pad(uppers.unsqueeze(0), (0, 0, 1, 0)) + X = uppers * torch.rand(4, base.lengthscale.shape[-1], **tkwargs) + Y = 10 * kernel(X).cholesky() @ 
torch.randn(4, 1, **tkwargs) + input_transform = Normalize(d=X.shape[-1], bounds=bounds) + outcome_transform = Standardize(m=Y.shape[-1]) + + # SingleTaskGP w/ inferred noise in eval mode + self.inferred_noise_gp = SingleTaskGP( + train_X=X, + train_Y=Y, + covar_module=deepcopy(kernel), + input_transform=deepcopy(input_transform), + outcome_transform=deepcopy(outcome_transform), + ).eval() + + # SingleTaskGP with observed noise in train mode + self.observed_noise_gp = SingleTaskGP( + train_X=X, + train_Y=Y, + train_Yvar=0.01 * torch.rand_like(Y), + covar_module=kernel, + input_transform=input_transform, + outcome_transform=outcome_transform, + ) - seed = 0 - for kernel in ( - ScaleKernel(MaternKernel(nu=2.5, ard_num_dims=2, batch_shape=Size([]))), - ): - with torch.random.fork_rng(): - torch.manual_seed(seed) - tkwargs = {"device": self.device, "dtype": torch.float64} - - base = kernel.base_kernel if isinstance(kernel, ScaleKernel) else kernel - base.lengthscale = 0.1 + 0.3 * torch.rand_like(base.lengthscale) - kernel.to(**tkwargs) - - uppers = 1 + 9 * torch.rand(base.lengthscale.shape[-1], **tkwargs) - bounds = pad(uppers.unsqueeze(0), (0, 0, 1, 0)) - - X = uppers * torch.rand(4, base.lengthscale.shape[-1], **tkwargs) - Y = 10 * kernel(X).cholesky() @ torch.randn(4, 1, **tkwargs) - if kernel.batch_shape: - Y = Y.squeeze(-1).transpose(0, 1) # n x m - - input_transform = Normalize(d=X.shape[-1], bounds=bounds) - outcome_transform = Standardize(m=Y.shape[-1]) - - # SingleTaskGP in eval mode - self.models[SingleTaskGP].append( - SingleTaskGP( - train_X=X, - train_Y=Y, - covar_module=deepcopy(kernel), - input_transform=deepcopy(input_transform), - outcome_transform=deepcopy(outcome_transform), - ) - .to(**tkwargs) - .eval() - ) - - # FixedNoiseGP in train mode - self.models[FixedNoiseGP].append( - FixedNoiseGP( - train_X=X, - train_Y=Y, - train_Yvar=0.01 * torch.rand_like(Y), - covar_module=kernel, - input_transform=input_transform, - outcome_transform=outcome_transform, - ).to(**tkwargs) - ) - - # SingleTaskVariationalGP in train mode - self.models[SingleTaskVariationalGP].append( - SingleTaskVariationalGP( - train_X=X, - train_Y=Y, - covar_module=kernel, - input_transform=input_transform, - outcome_transform=outcome_transform, - ).to(**tkwargs) - ) - - seed += 1 + # SingleTaskVariationalGP in train mode + self.variational_gp = SingleTaskVariationalGP( + train_X=X, + train_Y=Y, + covar_module=kernel, + input_transform=input_transform, + outcome_transform=outcome_transform, + ).to(**tkwargs) def test_draw_matheron_paths(self): - for seed, models in enumerate(self.models.values()): - for model, sample_shape in product(models, [Size([1024]), Size([32, 32])]): - with torch.random.fork_rng(): - torch.random.manual_seed(seed) - paths = draw_matheron_paths(model=model, sample_shape=sample_shape) - self.assertIsInstance(paths, MatheronPath) - self._test_draw_matheron_paths(model, paths, sample_shape) + for seed, model in enumerate( + (self.inferred_noise_gp, self.observed_noise_gp, self.variational_gp) + ): + for sample_shape in [Size([1024]), Size([32, 32])]: + torch.random.manual_seed(seed) + paths = draw_matheron_paths(model=model, sample_shape=sample_shape) + self.assertIsInstance(paths, MatheronPath) + self._test_draw_matheron_paths(model, paths, sample_shape) with self.subTest("test_model_list"): - model_list = ModelListGP( - self.models[SingleTaskGP][0], self.models[FixedNoiseGP][0] - ) + model_list = ModelListGP(self.inferred_noise_gp, self.observed_noise_gp) path_list = 
draw_matheron_paths(model_list, sample_shape=sample_shape) (train_X,) = get_train_inputs(model_list.models[0], transformed=False) X = torch.zeros( diff --git a/test/sampling/pathwise/test_prior_samplers.py b/test/sampling/pathwise/test_prior_samplers.py index 871c23eb76..5c562eb024 100644 --- a/test/sampling/pathwise/test_prior_samplers.py +++ b/test/sampling/pathwise/test_prior_samplers.py @@ -12,12 +12,7 @@ from unittest.mock import MagicMock import torch -from botorch.models import ( - FixedNoiseGP, - ModelListGP, - SingleTaskGP, - SingleTaskVariationalGP, -) +from botorch.models import ModelListGP, SingleTaskGP, SingleTaskVariationalGP from botorch.models.transforms.input import Normalize from botorch.models.transforms.outcome import Standardize from botorch.sampling.pathwise import ( @@ -64,8 +59,8 @@ def setUp(self) -> None: input_transform = Normalize(d=X.shape[-1], bounds=bounds) outcome_transform = Standardize(m=Y.shape[-1]) - # SingleTaskGP in eval mode - self.models[SingleTaskGP].append( + # SingleTaskGP w/ inferred noise in eval mode + self.models["inferred"].append( SingleTaskGP( train_X=X, train_Y=Y, @@ -77,9 +72,9 @@ def setUp(self) -> None: .eval() ) - # FixedNoiseGP in train mode - self.models[FixedNoiseGP].append( - FixedNoiseGP( + # SingleTaskGP w/ observed noise in train mode + self.models["observed"].append( + SingleTaskGP( train_X=X, train_Y=Y, train_Yvar=0.01 * torch.rand_like(Y), @@ -92,7 +87,7 @@ def setUp(self) -> None: # SingleTaskVariationalGP in train mode # When batched, uses a multitask format which break the tests below if not kernel.batch_shape: - self.models[SingleTaskVariationalGP].append( + self.models["variational"].append( SingleTaskVariationalGP( train_X=X, train_Y=Y, @@ -119,7 +114,7 @@ def test_draw_kernel_feature_paths(self): with self.subTest("test_model_list"): model_list = ModelListGP( - self.models[SingleTaskGP][0], self.models[FixedNoiseGP][0] + self.models["inferred"][0], self.models["observed"][0] ) path_list = draw_kernel_feature_paths( model=model_list, @@ -136,7 +131,7 @@ def test_draw_kernel_feature_paths(self): self.assertEqual(len(sample_list), len(path_list.paths)) with self.subTest("test_initialization"): - model = self.models[SingleTaskGP][0] + model = self.models["inferred"][0] sample_shape = torch.Size([16]) expected_weight_shape = ( sample_shape + model.covar_module.batch_shape + (self.num_features,) diff --git a/test/sampling/pathwise/test_update_strategies.py b/test/sampling/pathwise/test_update_strategies.py index 71f2d717ab..1e85bac474 100644 --- a/test/sampling/pathwise/test_update_strategies.py +++ b/test/sampling/pathwise/test_update_strategies.py @@ -12,7 +12,7 @@ from unittest.mock import patch import torch -from botorch.models import FixedNoiseGP, SingleTaskGP, SingleTaskVariationalGP +from botorch.models import SingleTaskGP, SingleTaskVariationalGP from botorch.models.transforms.input import Normalize from botorch.models.transforms.outcome import Standardize from botorch.sampling.pathwise import ( @@ -61,8 +61,8 @@ def setUp(self) -> None: input_transform = Normalize(d=X.shape[-1], bounds=bounds) outcome_transform = Standardize(m=Y.shape[-1]) - # SingleTaskGP in eval mode - self.models[SingleTaskGP].append( + # SingleTaskGP w/ inferred noise in eval mode + self.models["inferred"].append( SingleTaskGP( train_X=X, train_Y=Y, @@ -74,9 +74,9 @@ def setUp(self) -> None: .eval() ) - # FixedNoiseGP in train mode - self.models[FixedNoiseGP].append( - FixedNoiseGP( + # SingleTaskGP w/ observed noise in train mode + 
self.models["observed"].append( + SingleTaskGP( train_X=X, train_Y=Y, train_Yvar=0.01 * torch.rand_like(Y), @@ -89,7 +89,7 @@ def setUp(self) -> None: # SingleTaskVariationalGP in train mode # When batched, uses a multitask format which break the tests below if not kernel.batch_shape: - self.models[SingleTaskVariationalGP].append( + self.models["variational"].append( SingleTaskVariationalGP( train_X=X, train_Y=Y, diff --git a/test/test_cross_validation.py b/test/test_cross_validation.py index 9c6308c2b1..26c60d8e31 100644 --- a/test/test_cross_validation.py +++ b/test/test_cross_validation.py @@ -10,7 +10,7 @@ import torch from botorch.cross_validation import batch_cross_validation, gen_loo_cv_folds from botorch.exceptions.warnings import OptimizationWarning -from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP +from botorch.models.gp_regression import SingleTaskGP from botorch.utils.testing import _get_random_data, BotorchTestCase from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood @@ -57,7 +57,7 @@ def test_single_task_batch_cv(self): self.assertEqual(cv_results.posterior.mean.shape, expected_shape) self.assertEqual(cv_results.observed_Y.shape, expected_shape) - # Test FixedNoiseGP + # Test with noise observations noisy_cv_folds = gen_loo_cv_folds( train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar ) @@ -71,7 +71,7 @@ def test_single_task_batch_cv(self): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=OptimizationWarning) cv_results = batch_cross_validation( - model_cls=FixedNoiseGP, + model_cls=SingleTaskGP, mll_cls=ExactMarginalLogLikelihood, cv_folds=noisy_cv_folds, fit_args={"optimizer_kwargs": {"options": {"maxiter": 1}}}, diff --git a/test/test_end_to_end.py b/test/test_end_to_end.py index 51b45f75cf..179bd8917e 100644 --- a/test/test_end_to_end.py +++ b/test/test_end_to_end.py @@ -11,7 +11,7 @@ from botorch.acquisition import ExpectedImprovement, qExpectedImprovement from botorch.exceptions.warnings import OptimizationWarning from botorch.fit import fit_gpytorch_mll -from botorch.models import FixedNoiseGP, SingleTaskGP +from botorch.models.gp_regression import SingleTaskGP from botorch.optim import optimize_acqf from botorch.utils.testing import BotorchTestCase from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood @@ -56,7 +56,7 @@ def _setUp(self, double=False): optimizer_kwargs={"options": {"maxiter": 5}}, max_attempts=1, ) - model_fn = FixedNoiseGP( + model_fn = SingleTaskGP( self.train_x, self.train_y, self.train_yvar.expand_as(self.train_y) ) self.model_fn = model_fn.to(device=self.device, dtype=dtype) diff --git a/test/utils/test_gp_sampling.py b/test/utils/test_gp_sampling.py index 30a364fd15..9d17453aad 100644 --- a/test/utils/test_gp_sampling.py +++ b/test/utils/test_gp_sampling.py @@ -12,7 +12,7 @@ from botorch.models.converter import batched_to_model_list from botorch.models.deterministic import DeterministicModel from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP -from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP +from botorch.models.gp_regression import SingleTaskGP from botorch.models.model import ModelList from botorch.models.multitask import MultiTaskGP from botorch.models.transforms.input import Normalize @@ -652,7 +652,7 @@ def test_get_gp_samples(self): def test_with_fixed_noise(self): for n_samples in (1, 20): gp_samples = get_gp_samples( - model=FixedNoiseGP( + model=SingleTaskGP( torch.rand(5, 3, 
dtype=torch.double), torch.randn(5, 1, dtype=torch.double), torch.rand(5, 1, dtype=torch.double) * 0.1, diff --git a/test/utils/test_sampling.py b/test/utils/test_sampling.py index 4ef8643e96..8ef495df1b 100644 --- a/test/utils/test_sampling.py +++ b/test/utils/test_sampling.py @@ -14,7 +14,7 @@ import numpy as np import torch from botorch.exceptions.errors import BotorchError -from botorch.models import FixedNoiseGP +from botorch.models.gp_regression import SingleTaskGP from botorch.sampling.pathwise import draw_matheron_paths from botorch.utils.sampling import ( _convert_bounds_to_inequality_constraints, @@ -514,7 +514,7 @@ def test_optimize_posterior_samples(self): # having a noiseless model all but guarantees that the found optima # will be better than the observations - model = FixedNoiseGP(X, Y, torch.full_like(Y, eps)) + model = SingleTaskGP(X, Y, torch.full_like(Y, eps)) paths = draw_matheron_paths( model=model, sample_shape=torch.Size([num_optima]) )
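The models.md hunk near the top of this diff describes the minimal `Model` contract: a `posterior` method that returns a `Posterior` supporting `rsample()`. Below is a short sketch of how one of the deterministic models that section links to satisfies the same contract; the cost function is a made-up placeholder for illustration, not something defined in this diff.

import torch
from botorch.models.deterministic import GenericDeterministicModel

# A known, closed-form cost expressed as a deterministic model:
# cost = 1 + sum of the design variables (hypothetical example).
cost_model = GenericDeterministicModel(f=lambda X: 1.0 + X.sum(dim=-1, keepdim=True))

X = torch.rand(4, 3)
posterior = cost_model.posterior(X)  # same `posterior` entry point as any other Model
samples = posterior.rsample()        # `rsample()` is available even though the output is deterministic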
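The test changes in this diff all follow one substitution pattern: construct a `SingleTaskGP` and pass `train_Yvar` wherever a `FixedNoiseGP` was previously constructed, then assert on the likelihood type instead of the model class. A minimal sketch of that pattern with arbitrary toy data (shapes chosen for illustration only, not taken from the tests):

import torch
from botorch.models import SingleTaskGP
from gpytorch.likelihoods.gaussian_likelihood import (
    FixedNoiseGaussianLikelihood,
    GaussianLikelihood,
)

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
train_Yvar = torch.full_like(train_Y, 0.01)

# Observed noise: passing train_Yvar yields a FixedNoiseGaussianLikelihood,
# which is what the fantasy-based NEI / LogNEI acquisition functions require.
observed = SingleTaskGP(train_X, train_Y, train_Yvar=train_Yvar)
assert isinstance(observed.likelihood, FixedNoiseGaussianLikelihood)

# Inferred noise: omitting train_Yvar falls back to a GaussianLikelihood
# with a learned noise level.
inferred = SingleTaskGP(train_X, train_Y)
assert isinstance(inferred.likelihood, GaussianLikelihood)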