Skip to content

Commit

Permalink
Apply input transforms when computing MLL in model closures
Browse files Browse the repository at this point in the history
Summary:
During model training, the input transforms are applied in `model.forward`. While evaluating the model closures, we pass in the train inputs to the `mll`, which passed them down to the `likelihood`. If we don't transform the inputs before passing them into `mll`, we end up evaluating `model.forward` and `likelihood` using different inputs.

This is not an issue during the `posterior` evaluation, since the transforms are applied in `model.posterior` before being passed to `model.__call__` and `likelihood`.

This diff updates the model closures to transform the inputs before passing them into `mll`.

Fixes #2515

Differential Revision: D62497392
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Sep 11, 2024
1 parent 4d49bf7 commit 83fc1c2
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 32 deletions.
27 changes: 21 additions & 6 deletions botorch/optim/closures/model_closures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from __future__ import annotations

from collections.abc import Sequence

from itertools import chain, repeat
from types import NoneType
from typing import Any, Callable, Optional
Expand Down Expand Up @@ -174,9 +173,17 @@ def _get_loss_closure_exact_internal(
r"""ExactMarginalLogLikelihood loss closure with internally managed data."""

def closure(**kwargs: Any) -> Tensor:
model_output = mll.model(*mll.model.train_inputs)
model = mll.model
# The inputs will get transformed in forward here.
model_output = model(*model.train_inputs)
log_likelihood = mll(
model_output, mll.model.train_targets, *mll.model.train_inputs, **kwargs
model_output,
model.train_targets,
# During model training, the model inputs get transformed in the forward pass.
# The train_inputs property is not transformed yet, so we need to transform
# it before passing it to the likelihood for consistency.
*(model.transform_inputs(X=t_in) for t_in in model.train_inputs),
**kwargs,
)
return -log_likelihood

Expand All @@ -190,11 +197,19 @@ def _get_loss_closure_sum_internal(
r"""SumMarginalLogLikelihood loss closure with internally managed data."""

def closure(**kwargs: Any) -> Tensor:
model_output = mll.model(*mll.model.train_inputs)
model = mll.model
# The inputs will get transformed in forward here.
model_output = model(*model.train_inputs)
log_likelihood = mll(
model_output,
mll.model.train_targets,
*map(list, mll.model.train_inputs),
model.train_targets,
# During model training, the model inputs get transformed in the forward pass.
# The train_inputs property is not transformed yet, so we need to transform
# it before passing it to the likelihood for consistency.
*(
(model.transform_inputs(X=t_in) for t_in in sub_t_in)
for sub_t_in in model.train_inputs
),
**kwargs,
)
return -log_likelihood
Expand Down
119 changes: 93 additions & 26 deletions test/optim/closures/test_model_closures.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,67 @@
)
from botorch.utils.testing import BotorchTestCase
from gpytorch import settings as gpytorch_settings
from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood, SumMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.module import Module
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset


# Mock wrapping the __call__ directly is leading to errors like
# TypeError: super(type, obj): obj must be an instance or subtype of type
# so, doing this manually here.
class WrapperLikelihood(GaussianLikelihood):
def __init__(self, base_likelihood: GaussianLikelihood):
Module.__init__(self)
self.base_likelihood = base_likelihood
self.call_args = []

def __call__(self, *args, **kwargs):
# Store the train inputs arg for testing.
self.call_args.append(args[1])
return self.base_likelihood(*args, **kwargs)


def _get_mlls(
device: torch.device, wrap_likelihood: bool = False
) -> tuple[Tensor, list[MarginalLogLikelihood]]:
"""Returns the train X, along two MLLs: one for a SingleTaskGP and
one for a ModelListGP.
Args:
device: The device to use.
wrap_likelihood: If True, wrap the likelihood in a WrapperLikelihood.
This is useful for comparing call args later.
"""
with torch.random.fork_rng():
torch.manual_seed(0)
# Inputs are not in the unit cube to ensure input transform is applied.
train_X = torch.linspace(0, 5, 10).unsqueeze(-1)
train_Y = torch.sin((2 * pi) * train_X)
train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
mlls = []
model = SingleTaskGP(
train_X=train_X,
train_Y=train_Y,
input_transform=Normalize(d=1),
outcome_transform=Standardize(m=1),
)
if wrap_likelihood:
model.likelihood = WrapperLikelihood(model.likelihood)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
mlls.append(mll.to(device=device, dtype=torch.double))

model = ModelListGP(model, model)
mll = SumMarginalLogLikelihood(model.likelihood, model)
mlls.append(mll.to(device=device, dtype=torch.double))
return train_X.to(device=device, dtype=torch.double), mlls


class TestLossClosures(BotorchTestCase):
def setUp(self):
super().setUp()
with torch.random.fork_rng():
torch.manual_seed(0)
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
train_Y = torch.sin((2 * pi) * train_X)
train_Y = train_Y + 0.1 * torch.randn_like(train_Y)

self.mlls = {}
model = SingleTaskGP(
train_X=train_X,
train_Y=train_Y,
input_transform=Normalize(d=1),
outcome_transform=Standardize(m=1),
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
self.mlls[type(mll), type(model.likelihood), type(model)] = mll.to(self.device)

model = ModelListGP(model, model)
mll = SumMarginalLogLikelihood(model.likelihood, model)
self.mlls[type(mll), type(model.likelihood), type(model)] = mll.to(self.device)

def test_main(self):
for mll in self.mlls.values():
def test_main(self) -> None:
for mll in _get_mlls(device=self.device)[1]:
out = mll.model(*mll.model.train_inputs)
loss = -mll(out, mll.model.train_targets).sum()
loss.backward()
Expand All @@ -63,8 +95,8 @@ def test_main(self):
self.assertTrue(loss.equal(_loss))
self.assertTrue(all(a.equal(b) for a, b in zip_longest(grads, _grads)))

def test_data_loader(self):
for mll in self.mlls.values():
def test_data_loader(self) -> None:
for mll in _get_mlls(device=self.device)[1]:
if type(mll) is not ExactMarginalLogLikelihood:
continue

Expand All @@ -86,3 +118,38 @@ def test_data_loader(self):
closure = get_loss_closure_with_grads(mll, params, data_loader=loader)
with self.assertRaisesRegex(TypeError, "Expected .* a batch of tensors"):
closure()

def test_with_input_transforms(self) -> None:
# This test reproduces the bug reported in issue #2515.
train_X, mlls = _get_mlls(device=self.device, wrap_likelihood=True)
for mll in mlls:
if isinstance(mll, SumMarginalLogLikelihood):
# The likelihood is called twice here since it is the same
# likelihood in both child models.
likelihood = mll.model.models[0].likelihood
expected_calls1 = 2 # In the closure call.
expected_calls2 = 6 # Closure + posterior calls.
else:
likelihood = mll.model.likelihood
expected_calls1 = 1 # In the closure call.
expected_calls2 = 4 # Closure + posterior calls.
likelihood.call_args = [] # reset since it is shared between the models.
params = {n: p for n, p in mll.named_parameters() if p.requires_grad}
# Evaluate the closure to mimic the model fitting process.
mll.train()
closure = get_loss_closure_with_grads(mll, params)
closure()
self.assertEqual(len(likelihood.call_args), expected_calls1)
# Call the model posterior to reproduce post-fitting usage.
mll.model.posterior(train_X, observation_noise=True)
# Compare the call args to ensure they're all the same.
# Likelihood is called twice on model(X) and once for adding the noise.
self.assertEqual(len(likelihood.call_args), expected_calls2)
arg0 = likelihood.call_args[0]
for i in range(1, expected_calls2):
argi = likelihood.call_args[i]
# The arg may be a tensor or a single element list of the tensor.
self.assertAllClose(
arg0 if isinstance(arg0, Tensor) else arg0[0],
argi if isinstance(argi, Tensor) else argi[0],
)

0 comments on commit 83fc1c2

Please sign in to comment.