diff --git a/botorch/optim/closures/model_closures.py b/botorch/optim/closures/model_closures.py index e511d8b958..b740a0d577 100644 --- a/botorch/optim/closures/model_closures.py +++ b/botorch/optim/closures/model_closures.py @@ -9,7 +9,6 @@ from __future__ import annotations from collections.abc import Sequence - from itertools import chain, repeat from types import NoneType from typing import Any, Callable, Optional @@ -174,9 +173,17 @@ def _get_loss_closure_exact_internal( r"""ExactMarginalLogLikelihood loss closure with internally managed data.""" def closure(**kwargs: Any) -> Tensor: - model_output = mll.model(*mll.model.train_inputs) + model = mll.model + # The inputs will get transformed in forward here. + model_output = model(*model.train_inputs) log_likelihood = mll( - model_output, mll.model.train_targets, *mll.model.train_inputs, **kwargs + model_output, + model.train_targets, + # During model training, the model inputs get transformed in the forward pass. + # The train_inputs property is not transformed yet, so we need to transform + # it before passing it to the likelihood for consistency. + *(model.transform_inputs(X=t_in) for t_in in model.train_inputs), + **kwargs, ) return -log_likelihood @@ -190,11 +197,19 @@ def _get_loss_closure_sum_internal( r"""SumMarginalLogLikelihood loss closure with internally managed data.""" def closure(**kwargs: Any) -> Tensor: - model_output = mll.model(*mll.model.train_inputs) + model = mll.model + # The inputs will get transformed in forward here. + model_output = model(*model.train_inputs) log_likelihood = mll( model_output, - mll.model.train_targets, - *map(list, mll.model.train_inputs), + model.train_targets, + # During model training, the model inputs get transformed in the forward pass. + # The train_inputs property is not transformed yet, so we need to transform + # it before passing it to the likelihood for consistency. + *( + (model.transform_inputs(X=t_in) for t_in in sub_t_in) + for sub_t_in in model.train_inputs + ), **kwargs, ) return -log_likelihood diff --git a/test/optim/closures/test_model_closures.py b/test/optim/closures/test_model_closures.py index 0d483bd819..ebbe8028e1 100644 --- a/test/optim/closures/test_model_closures.py +++ b/test/optim/closures/test_model_closures.py @@ -17,35 +17,67 @@ ) from botorch.utils.testing import BotorchTestCase from gpytorch import settings as gpytorch_settings +from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood from gpytorch.mlls import ExactMarginalLogLikelihood, SumMarginalLogLikelihood +from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood +from gpytorch.module import Module +from torch import Tensor from torch.utils.data import DataLoader, TensorDataset +# Mock wrapping the __call__ directly is leading to errors like +# TypeError: super(type, obj): obj must be an instance or subtype of type +# so, doing this manually here. +class WrapperLikelihood(GaussianLikelihood): + def __init__(self, base_likelihood: GaussianLikelihood): + Module.__init__(self) + self.base_likelihood = base_likelihood + self.call_args = [] + + def __call__(self, *args, **kwargs): + # Store the train inputs arg for testing. + self.call_args.append(args[1]) + return self.base_likelihood(*args, **kwargs) + + +def _get_mlls( + device: torch.device, wrap_likelihood: bool = False +) -> tuple[Tensor, list[MarginalLogLikelihood]]: + """Returns the train X, along two MLLs: one for a SingleTaskGP and + one for a ModelListGP. + + Args: + device: The device to use. + wrap_likelihood: If True, wrap the likelihood in a WrapperLikelihood. + This is useful for comparing call args later. + """ + with torch.random.fork_rng(): + torch.manual_seed(0) + # Inputs are not in the unit cube to ensure input transform is applied. + train_X = torch.linspace(0, 5, 10).unsqueeze(-1) + train_Y = torch.sin((2 * pi) * train_X) + train_Y = train_Y + 0.1 * torch.randn_like(train_Y) + mlls = [] + model = SingleTaskGP( + train_X=train_X, + train_Y=train_Y, + input_transform=Normalize(d=1), + outcome_transform=Standardize(m=1), + ) + if wrap_likelihood: + model.likelihood = WrapperLikelihood(model.likelihood) + mll = ExactMarginalLogLikelihood(model.likelihood, model) + mlls.append(mll.to(device=device, dtype=torch.double)) + + model = ModelListGP(model, model) + mll = SumMarginalLogLikelihood(model.likelihood, model) + mlls.append(mll.to(device=device, dtype=torch.double)) + return train_X.to(device=device, dtype=torch.double), mlls + + class TestLossClosures(BotorchTestCase): - def setUp(self): - super().setUp() - with torch.random.fork_rng(): - torch.manual_seed(0) - train_X = torch.linspace(0, 1, 10).unsqueeze(-1) - train_Y = torch.sin((2 * pi) * train_X) - train_Y = train_Y + 0.1 * torch.randn_like(train_Y) - - self.mlls = {} - model = SingleTaskGP( - train_X=train_X, - train_Y=train_Y, - input_transform=Normalize(d=1), - outcome_transform=Standardize(m=1), - ) - mll = ExactMarginalLogLikelihood(model.likelihood, model) - self.mlls[type(mll), type(model.likelihood), type(model)] = mll.to(self.device) - - model = ModelListGP(model, model) - mll = SumMarginalLogLikelihood(model.likelihood, model) - self.mlls[type(mll), type(model.likelihood), type(model)] = mll.to(self.device) - - def test_main(self): - for mll in self.mlls.values(): + def test_main(self) -> None: + for mll in _get_mlls(device=self.device)[1]: out = mll.model(*mll.model.train_inputs) loss = -mll(out, mll.model.train_targets).sum() loss.backward() @@ -63,8 +95,8 @@ def test_main(self): self.assertTrue(loss.equal(_loss)) self.assertTrue(all(a.equal(b) for a, b in zip_longest(grads, _grads))) - def test_data_loader(self): - for mll in self.mlls.values(): + def test_data_loader(self) -> None: + for mll in _get_mlls(device=self.device)[1]: if type(mll) is not ExactMarginalLogLikelihood: continue @@ -86,3 +118,38 @@ def test_data_loader(self): closure = get_loss_closure_with_grads(mll, params, data_loader=loader) with self.assertRaisesRegex(TypeError, "Expected .* a batch of tensors"): closure() + + def test_with_input_transforms(self) -> None: + # This test reproduces the bug reported in issue #2515. + train_X, mlls = _get_mlls(device=self.device, wrap_likelihood=True) + for mll in mlls: + if isinstance(mll, SumMarginalLogLikelihood): + # The likelihood is called twice here since it is the same + # likelihood in both child models. + likelihood = mll.model.models[0].likelihood + expected_calls1 = 2 # In the closure call. + expected_calls2 = 6 # Closure + posterior calls. + else: + likelihood = mll.model.likelihood + expected_calls1 = 1 # In the closure call. + expected_calls2 = 4 # Closure + posterior calls. + likelihood.call_args = [] # reset since it is shared between the models. + params = {n: p for n, p in mll.named_parameters() if p.requires_grad} + # Evaluate the closure to mimic the model fitting process. + mll.train() + closure = get_loss_closure_with_grads(mll, params) + closure() + self.assertEqual(len(likelihood.call_args), expected_calls1) + # Call the model posterior to reproduce post-fitting usage. + mll.model.posterior(train_X, observation_noise=True) + # Compare the call args to ensure they're all the same. + # Likelihood is called twice on model(X) and once for adding the noise. + self.assertEqual(len(likelihood.call_args), expected_calls2) + arg0 = likelihood.call_args[0] + for i in range(1, expected_calls2): + argi = likelihood.call_args[i] + # The arg may be a tensor or a single element list of the tensor. + self.assertAllClose( + arg0 if isinstance(arg0, Tensor) else arg0[0], + argi if isinstance(argi, Tensor) else argi[0], + )