Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added posterior_transform to posterior method in ApproximateGPyTorchModel #2531

Closed
16 changes: 13 additions & 3 deletions botorch/models/approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@

import copy
import warnings

from typing import Optional, Union

import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.warnings import UserInputWarning
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.transforms.input import InputTransform
Expand Down Expand Up @@ -146,8 +146,16 @@ def train(self, mode: bool = True) -> Self:
return Module.train(self, mode=mode)

def posterior(
self, X, output_indices=None, observation_noise=False, *args, **kwargs
self,
X,
output_indices: Optional[list[int]] = None,
observation_noise: bool = False,
posterior_transform: Optional[PosteriorTransform] = None,
) -> GPyTorchPosterior:
if output_indices is not None:
raise NotImplementedError( # pragma: no cover
f"{self.__class__.__name__}.posterior does not support output indices."
)
self.eval() # make sure model is in eval mode

# input transforms are applied at `posterior` in `eval` mode, and at
Expand All @@ -161,11 +169,13 @@ def posterior(
X = X.unsqueeze(-3).repeat(*[1] * (X_ndim - 2), self.num_outputs, 1, 1)
dist = self.model(X)
if observation_noise:
dist = self.likelihood(dist, *args, **kwargs)
dist = self.likelihood(dist)

posterior = GPyTorchPosterior(distribution=dist)
if hasattr(self, "outcome_transform"):
posterior = self.outcome_transform.untransform_posterior(posterior)
if posterior_transform is not None:
posterior = posterior_transform(posterior)
SaiAakash marked this conversation as resolved.
Show resolved Hide resolved
return posterior

def forward(self, X) -> MultivariateNormal:
Expand Down
11 changes: 11 additions & 0 deletions test/models/test_approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings

import torch
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.exceptions.warnings import UserInputWarning
from botorch.fit import fit_gpytorch_mll
from botorch.models.approximate_gp import (
Expand Down Expand Up @@ -103,6 +104,16 @@ def test_posterior(self):
# test batch_shape property
self.assertEqual(model.batch_shape, tx.shape[:-2])

# Test that checks if posterior_transform is correctly applied
[tx1, ty1, test1] = all_tests["non_batched_mo"]
model1 = SingleTaskVariationalGP(tx1, ty1, inducing_points=tx1)
posterior_transform = ScalarizedPosteriorTransform(
weights=torch.tensor([1.0, 1.0])
Balandat marked this conversation as resolved.
Show resolved Hide resolved
)
posterior1 = model1.posterior(test1, posterior_transform=posterior_transform)
self.assertIsInstance(posterior1, GPyTorchPosterior)
self.assertEqual(posterior1.mean.shape[1], 1)

def test_variational_setUp(self):
for dtype in [torch.float, torch.double]:
train_X = torch.rand(10, 1, device=self.device, dtype=dtype)
Expand Down
Loading