Skip to content

Commit

Permalink
raise error for MTGP in batch_cross_validation (pytorch#2554)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2554

see title.

pytorch#2553

Reviewed By: saitcakmak, Balandat

Differential Revision: D63464560
  • Loading branch information
sdaulton authored and facebook-github-bot committed Sep 26, 2024
1 parent ad9978f commit 4c8b32a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
6 changes: 6 additions & 0 deletions botorch/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from typing import Any, NamedTuple, Optional

import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.fit import fit_gpytorch_mll
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.multitask import MultiTaskGP
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from torch import Tensor
Expand Down Expand Up @@ -166,6 +168,10 @@ def batch_cross_validation(
... },
... )
"""
if issubclass(model_cls, MultiTaskGP):
raise UnsupportedError(
"Multi-task GPs are not currently supported by `batch_cross_validation`."
)
model_init_kws = model_init_kwargs if model_init_kwargs is not None else {}
if cv_folds.train_Yvar is not None:
model_init_kws["train_Yvar"] = cv_folds.train_Yvar
Expand Down
17 changes: 17 additions & 0 deletions test/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

import torch
from botorch.cross_validation import batch_cross_validation, gen_loo_cv_folds
from botorch.exceptions.errors import UnsupportedError
from botorch.exceptions.warnings import OptimizationWarning
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.multitask import MultiTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from botorch.utils.testing import _get_random_data, BotorchTestCase
Expand Down Expand Up @@ -107,3 +109,18 @@ def test_single_task_batch_cv(self) -> None:
cv_results.posterior.mean.device.type, self.device.type
)
self.assertIs(cv_results.posterior.mean.dtype, dtype)

def test_mtgp(self):
train_X, train_Y = _get_random_data(
batch_shape=torch.Size(), m=1, n=3, device=self.device
)
cv_folds = gen_loo_cv_folds(train_X=train_X, train_Y=train_Y)
with self.assertRaisesRegex(
UnsupportedError, "Multi-task GPs are not currently supported."
):
batch_cross_validation(
model_cls=MultiTaskGP,
mll_cls=ExactMarginalLogLikelihood,
cv_folds=cv_folds,
fit_args={"optimizer_kwargs": {"options": {"maxiter": 1}}},
)

0 comments on commit 4c8b32a

Please sign in to comment.