Support observed noise in MixedSingleTaskGP (pytorch#2054)
Summary:
Pull Request resolved: pytorch#2054

Supporting observed noise was blocked on `SingleTaskGP` not supporting `train_Yvar`. We can easily support it after pytorch#2052.
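As an illustration of the usage this unblocks (not part of the commit; the toy tensors and the categorical column index are invented for the example):

```python
import torch
from botorch.models.gp_regression_mixed import MixedSingleTaskGP

# Toy data: 8 points with two continuous dims and one categorical dim (column 2).
train_X = torch.rand(8, 3, dtype=torch.double)
train_X[:, 2] = torch.randint(0, 3, (8,)).double()
train_Y = torch.randn(8, 1, dtype=torch.double)
train_Yvar = torch.full_like(train_Y, 0.1)  # observed measurement noise

# `train_Yvar` is the new argument: it is forwarded to `SingleTaskGP`, which
# then uses a fixed-noise likelihood instead of inferring a noise level.
model = MixedSingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    cat_dims=[2],
    train_Yvar=train_Yvar,
)
```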

Reviewed By: esantorella

Differential Revision: D50394746

fbshipit-source-id: d830d92031b4c3a1b190ce438cdf0513130ae2b0
saitcakmak authored and facebook-github-bot committed Oct 18, 2023
1 parent 6b6cc0d commit f70c8ec
Showing 2 changed files with 66 additions and 40 deletions.
15 changes: 5 additions & 10 deletions botorch/models/gp_regression_mixed.py
@@ -6,11 +6,9 @@
 
 from __future__ import annotations
 
-import warnings
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
-from botorch.exceptions.warnings import InputDataWarning
 from botorch.models.gp_regression import SingleTaskGP
 from botorch.models.kernels.categorical import CategoricalKernel
 from botorch.models.transforms.input import InputTransform
@@ -64,6 +62,7 @@ def __init__(
         train_X: Tensor,
         train_Y: Tensor,
         cat_dims: List[int],
+        train_Yvar: Optional[Tensor] = None,
         cont_kernel_factory: Optional[
             Callable[[torch.Size, int, List[int]], Kernel]
         ] = None,
@@ -78,6 +77,8 @@ def __init__(
             train_Y: A `batch_shape x n x m` tensor of training observations.
             cat_dims: A list of indices corresponding to the columns of
                 the input `X` that should be considered categorical features.
+            train_Yvar: An optional `batch_shape x n x m` tensor of observed
+                measurement noise.
             cont_kernel_factory: A method that accepts `batch_shape`, `ard_num_dims`,
                 and `active_dims` arguments and returns an instantiated GPyTorch
                 `Kernel` object to be used as the base kernel for the continuous
@@ -118,7 +119,7 @@ def cont_kernel_factory(
                     lengthscale_constraint=GreaterThan(1e-04),
                 )
 
-        if likelihood is None:
+        if likelihood is None and train_Yvar is None:
             # This Gamma prior is quite close to the Horseshoe prior
             min_noise = 1e-5 if train_X.dtype == torch.float else 1e-6
             likelihood = GaussianLikelihood(
@@ -173,6 +174,7 @@ def cont_kernel_factory(
         super().__init__(
            train_X=train_X,
            train_Y=train_Y,
+           train_Yvar=train_Yvar,
            likelihood=likelihood,
            covar_module=covar_module,
            outcome_transform=outcome_transform,
@@ -195,13 +197,6 @@ def construct_inputs(
             likelihood: Optional likelihood used to construct the model.
         """
         base_inputs = super().construct_inputs(training_data=training_data, **kwargs)
-        if base_inputs.pop("train_Yvar", None) is not None:
-            # TODO: Remove when SingleTaskGP supports optional Yvar [T162925473].
-            warnings.warn(
-                "`MixedSingleTaskGP` only supports inferred noise at the moment. "
-                "Ignoring the provided `train_Yvar` observations.",
-                InputDataWarning,
-            )
         return {
             **base_inputs,
             "cat_dims": categorical_features,
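With the warning above removed, `construct_inputs` now forwards a dataset's observed noise instead of dropping it. A minimal sketch of the resulting flow, modeled on the updated test below (toy data invented for the example):

```python
import torch
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from botorch.utils.datasets import SupervisedDataset

X = torch.rand(10, 3, dtype=torch.double)
Y = torch.rand(10, 1, dtype=torch.double)
Yvar = torch.full_like(Y, 0.1)

training_data = SupervisedDataset(
    X,
    Y,
    Yvar=Yvar,
    feature_names=["x0", "x1", "x2"],
    outcome_names=["y"],
)

# Previously this popped `train_Yvar` and emitted an InputDataWarning; now
# the returned kwargs carry it, so the model is built with observed noise.
model_kwargs = MixedSingleTaskGP.construct_inputs(
    training_data, categorical_features=[0, 1]
)
model = MixedSingleTaskGP(**model_kwargs)
```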
91 changes: 61 additions & 30 deletions test/models/test_gp_regression_mixed.py
@@ -8,7 +8,7 @@
 import warnings
 
 import torch
-from botorch.exceptions.warnings import InputDataWarning, OptimizationWarning
+from botorch.exceptions.warnings import OptimizationWarning
 from botorch.fit import fit_gpytorch_mll
 from botorch.models.converter import batched_to_model_list
 from botorch.models.gp_regression_mixed import MixedSingleTaskGP
@@ -21,21 +21,24 @@
 from gpytorch.kernels.kernel import AdditiveKernel, ProductKernel
 from gpytorch.kernels.matern_kernel import MaternKernel
 from gpytorch.kernels.scale_kernel import ScaleKernel
+from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
+from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
 from gpytorch.means import ConstantMean
 from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
 
 from .test_gp_regression import _get_pvar_expected
 
 
 class TestMixedSingleTaskGP(BotorchTestCase):
+    observed_noise = False
+
     def test_gp(self):
         d = 3
         bounds = torch.tensor([[-1.0] * d, [1.0] * d])
-        for batch_shape, m, ncat, dtype in itertools.product(
-            (torch.Size(), torch.Size([2])),
-            (1, 2),
-            (0, 1, 3),
-            (torch.float, torch.double),
+        for batch_shape, m, ncat, dtype, observed_noise in (
+            (torch.Size(), 1, 0, torch.float, False),
+            (torch.Size(), 2, 1, torch.double, True),
+            (torch.Size([2]), 2, 3, torch.double, False),
         ):
             tkwargs = {"device": self.device, "dtype": dtype}
             train_X, train_Y = _get_random_data(
@@ -62,7 +65,13 @@ def test_gp(self):
                     MixedSingleTaskGP(train_X, train_Y, cat_dims=cat_dims)
                 continue
 
-            model = MixedSingleTaskGP(train_X, train_Y, cat_dims=cat_dims)
+            train_Yvar = torch.full_like(train_Y, 0.1) if observed_noise else None
+            model = MixedSingleTaskGP(
+                train_X=train_X,
+                train_Y=train_Y,
+                cat_dims=cat_dims,
+                train_Yvar=train_Yvar,
+            )
             self.assertEqual(model._ignore_X_dims_scaling_check, cat_dims)
             mll = ExactMarginalLogLikelihood(model.likelihood, model).to(**tkwargs)
             with warnings.catch_warnings():
@@ -90,6 +99,10 @@ def test_gp(self):
             else:
                 self.assertIsInstance(model.covar_module, ScaleKernel)
                 self.assertIsInstance(model.covar_module.base_kernel, CategoricalKernel)
+            if observed_noise:
+                self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
+            else:
+                self.assertIsInstance(model.likelihood, GaussianLikelihood)
 
             # test posterior
             # test non batch evaluation
@@ -127,20 +140,24 @@ def test_gp(self):
         with self.assertRaisesRegex(NotImplementedError, "not supported"):
             batched_to_model_list(model)
 
-    def test_condition_on_observations(self):
+    def test_condition_on_observations__(self):
         d = 3
-        for batch_shape, m, ncat, dtype in itertools.product(
-            (torch.Size(), torch.Size([2])),
-            (1, 2),
-            (1, 2),
-            (torch.float, torch.double),
+        for batch_shape, m, ncat, dtype, observed_noise in (
+            (torch.Size(), 2, 1, torch.float, True),
+            (torch.Size([2]), 1, 2, torch.double, False),
         ):
             tkwargs = {"device": self.device, "dtype": dtype}
             train_X, train_Y = _get_random_data(
                 batch_shape=batch_shape, m=m, d=d, **tkwargs
             )
             cat_dims = list(range(ncat))
-            model = MixedSingleTaskGP(train_X, train_Y, cat_dims=cat_dims)
+            train_Yvar = torch.full_like(train_Y, 0.1) if observed_noise else None
+            model = MixedSingleTaskGP(
+                train_X=train_X,
+                train_Y=train_Y,
+                cat_dims=cat_dims,
+                train_Yvar=train_Yvar,
+            )
 
             # evaluate model
             model.posterior(torch.rand(torch.Size([4, d]), **tkwargs))
@@ -151,11 +168,16 @@ def test_condition_on_observations(self):
             X_fant, Y_fant = _get_random_data(
                 fant_shape + batch_shape, m=m, d=d, n=3, **tkwargs
             )
-            cm = model.condition_on_observations(X_fant, Y_fant)
+            additional_kwargs = (
+                {"noise": torch.full_like(Y_fant, 0.1)} if observed_noise else {}
+            )
+            cm = model.condition_on_observations(X_fant, Y_fant, **additional_kwargs)
             # fantasize at same input points (check proper broadcasting)
+            additional_kwargs = (
+                {"noise": torch.full_like(Y_fant[0], 0.1)} if observed_noise else {}
+            )
             cm_same_inputs = model.condition_on_observations(
-                X_fant[0],
-                Y_fant,
+                X_fant[0], Y_fant, **additional_kwargs
             )
 
             test_Xs = [
@@ -189,14 +211,20 @@ def test_condition_on_observations(self):
                 "train_Y": train_Y[0],
                 "cat_dims": cat_dims,
             }
+            if observed_noise:
+                model_kwargs_non_batch["train_Yvar"] = train_Yvar[0]
             model_non_batch = type(model)(**model_kwargs_non_batch)
             model_non_batch.load_state_dict(state_dict_non_batch)
             model_non_batch.eval()
             model_non_batch.likelihood.eval()
             model_non_batch.posterior(torch.rand(torch.Size([4, d]), **tkwargs))
+            additional_kwargs = (
+                {"noise": torch.full_like(Y_fant, 0.1)}
+                if observed_noise
+                else {}
+            )
             cm_non_batch = model_non_batch.condition_on_observations(
-                X_fant[0][0],
-                Y_fant[:, 0, :],
+                X_fant[0][0], Y_fant[:, 0, :], **additional_kwargs
             )
             non_batch_posterior = cm_non_batch.posterior(test_X)
             self.assertTrue(
@@ -218,18 +246,22 @@ def test_condition_on_observations(self):
 
     def test_fantasize(self):
         d = 3
-        for batch_shape, m, ncat, dtype in itertools.product(
-            (torch.Size(), torch.Size([2])),
-            (1, 2),
-            (1, 2),
-            (torch.float, torch.double),
+        for batch_shape, m, ncat, dtype, observed_noise in (
+            (torch.Size(), 2, 1, torch.float, True),
+            (torch.Size([2]), 1, 2, torch.double, False),
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
             train_X, train_Y = _get_random_data(
                 batch_shape=batch_shape, m=m, d=d, **tkwargs
             )
+            train_Yvar = torch.full_like(train_Y, 0.1) if observed_noise else None
             cat_dims = list(range(ncat))
-            model = MixedSingleTaskGP(train_X, train_Y, cat_dims=cat_dims)
+            model = MixedSingleTaskGP(
+                train_X=train_X,
+                train_Y=train_Y,
+                cat_dims=cat_dims,
+                train_Yvar=train_Yvar,
+            )
 
             # fantasize
             X_f = torch.rand(torch.Size(batch_shape + torch.Size([4, d])), **tkwargs)
@@ -295,10 +327,9 @@ def test_construct_inputs(self):
                feature_names=[f"x{i}" for i in range(d)],
                outcome_names=["y"],
            )
-            with self.assertWarnsRegex(InputDataWarning, "train_Yvar"):
-                model_kwargs = MixedSingleTaskGP.construct_inputs(
-                    training_data, categorical_features=cat_dims
-                )
+            model_kwargs = MixedSingleTaskGP.construct_inputs(
+                training_data, categorical_features=cat_dims
+            )
             self.assertTrue(X.equal(model_kwargs["train_X"]))
             self.assertTrue(Y.equal(model_kwargs["train_Y"]))
-            self.assertNotIn("train_Yvar", model_kwargs)
+            self.assertTrue(Y.equal(model_kwargs["train_Yvar"]))
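A behavioral detail the updated tests encode: once a model carries observed noise (a fixed-noise likelihood), `condition_on_observations` also expects the noise of the new observations via the `noise` keyword, whereas inferred-noise models simply omit it. A minimal sketch under the same assumptions as the tests above:

```python
import torch
from botorch.models.gp_regression_mixed import MixedSingleTaskGP

train_X = torch.rand(8, 3, dtype=torch.double)
train_Y = torch.randn(8, 1, dtype=torch.double)
train_Yvar = torch.full_like(train_Y, 0.1)
model = MixedSingleTaskGP(
    train_X=train_X, train_Y=train_Y, cat_dims=[0], train_Yvar=train_Yvar
)
model.posterior(torch.rand(4, 3, dtype=torch.double))  # evaluate once, as the tests do

X_new = torch.rand(3, 3, dtype=torch.double)
Y_new = torch.randn(3, 1, dtype=torch.double)
# Observed-noise model: supply per-observation noise for the new points.
cm = model.condition_on_observations(
    X_new, Y_new, noise=torch.full_like(Y_new, 0.1)
)
```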
