From 4c8b32adb816dc25be5d36f84dc34becc1298b67 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Thu, 26 Sep 2024 09:38:29 -0700 Subject: [PATCH] raise error for MTGP in batch_cross_validation (#2554) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/2554 see title. https://github.com/pytorch/botorch/issues/2553 Reviewed By: saitcakmak, Balandat Differential Revision: D63464560 --- botorch/cross_validation.py | 6 ++++++ test/test_cross_validation.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/botorch/cross_validation.py b/botorch/cross_validation.py index 2fb0831797..5478a1a539 100644 --- a/botorch/cross_validation.py +++ b/botorch/cross_validation.py @@ -13,8 +13,10 @@ from typing import Any, NamedTuple, Optional import torch +from botorch.exceptions.errors import UnsupportedError from botorch.fit import fit_gpytorch_mll from botorch.models.gpytorch import GPyTorchModel +from botorch.models.multitask import MultiTaskGP from botorch.posteriors.gpytorch import GPyTorchPosterior from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from torch import Tensor @@ -166,6 +168,10 @@ def batch_cross_validation( ... }, ... ) """ + if issubclass(model_cls, MultiTaskGP): + raise UnsupportedError( + "Multi-task GPs are not currently supported by `batch_cross_validation`." + ) model_init_kws = model_init_kwargs if model_init_kwargs is not None else {} if cv_folds.train_Yvar is not None: model_init_kws["train_Yvar"] = cv_folds.train_Yvar diff --git a/test/test_cross_validation.py b/test/test_cross_validation.py index 0a9d65995f..518fdd7afb 100644 --- a/test/test_cross_validation.py +++ b/test/test_cross_validation.py @@ -9,8 +9,10 @@ import torch from botorch.cross_validation import batch_cross_validation, gen_loo_cv_folds +from botorch.exceptions.errors import UnsupportedError from botorch.exceptions.warnings import OptimizationWarning from botorch.models.gp_regression import SingleTaskGP +from botorch.models.multitask import MultiTaskGP from botorch.models.transforms.input import Normalize from botorch.models.transforms.outcome import Standardize from botorch.utils.testing import _get_random_data, BotorchTestCase @@ -107,3 +109,18 @@ def test_single_task_batch_cv(self) -> None: cv_results.posterior.mean.device.type, self.device.type ) self.assertIs(cv_results.posterior.mean.dtype, dtype) + + def test_mtgp(self): + train_X, train_Y = _get_random_data( + batch_shape=torch.Size(), m=1, n=3, device=self.device + ) + cv_folds = gen_loo_cv_folds(train_X=train_X, train_Y=train_Y) + with self.assertRaisesRegex( + UnsupportedError, "Multi-task GPs are not currently supported." + ): + batch_cross_validation( + model_cls=MultiTaskGP, + mll_cls=ExactMarginalLogLikelihood, + cv_folds=cv_folds, + fit_args={"optimizer_kwargs": {"options": {"maxiter": 1}}}, + )