Skip to content

Commit

Permalink
Fix posterior method in BatchedMultiOutputGPyTorchModel for tracing…
Browse files Browse the repository at this point in the history
… JIT (pytorch#2592)

Summary:
## Motivation

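Fixes pytorch#2591. When `posterior` runs under `torch.jit.trace` (i.e. `torch.jit.is_tracing()` is true), the MultitaskMultivariateNormal (MTMVN) for the independent-task case is now built with `MultitaskMultivariateNormal.from_batch_mvn` instead of from a list of per-output MultivariateNormals; the non-traced code path is unchanged.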

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: pytorch#2592

Test Plan:
A unit test, `test_posterior_in_trace_mode`, has been added to `test/models/test_gpytorch.py`.

## Related PRs

NA

Reviewed By: saitcakmak, Balandat

Differential Revision: D64903356

Pulled By: sdaulton

fbshipit-source-id: 32fa2f108e99683d92344e31123a6bd07cc4113b
  • Loading branch information
SaiAakash authored and facebook-github-bot committed Oct 24, 2024
1 parent ccf278a commit e7539db
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 10 deletions.
25 changes: 15 additions & 10 deletions botorch/models/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,17 +446,22 @@ def posterior(
mvn = self(X)
mvn = self._apply_noise(X=X, mvn=mvn, observation_noise=observation_noise)
if self._num_outputs > 1:
mean_x = mvn.mean
covar_x = mvn.lazy_covariance_matrix
output_indices = output_indices or range(self._num_outputs)
mvns = [
MultivariateNormal(
mean_x.select(dim=output_dim_idx, index=t),
covar_x[(slice(None),) * output_dim_idx + (t,)],
if torch.jit.is_tracing():
mvn = MultitaskMultivariateNormal.from_batch_mvn(
mvn, task_dim=output_dim_idx
)
for t in output_indices
]
mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
else:
mean_x = mvn.mean
covar_x = mvn.lazy_covariance_matrix
output_indices = output_indices or range(self._num_outputs)
mvns = [
MultivariateNormal(
mean_x.select(dim=output_dim_idx, index=t),
covar_x[(slice(None),) * output_dim_idx + (t,)],
)
for t in output_indices
]
mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)

posterior = GPyTorchPosterior(distribution=mvn)
if hasattr(self, "outcome_transform"):
Expand Down
27 changes: 27 additions & 0 deletions test/models/test_gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP, IndependentModelList
from gpytorch.settings import trace_mode
from torch import Tensor


Expand Down Expand Up @@ -410,6 +411,32 @@ def test_posterior_transform(self):
post = model.posterior(torch.rand(3, 2, **tkwargs), posterior_transform=post_tf)
self.assertTrue(torch.equal(post.mean, torch.zeros(3, 1, **tkwargs)))

def test_posterior_in_trace_mode(self):
    """Check that ``posterior`` can be JIT-traced for a two-output batched model.

    A small wrapper module turns the posterior into plain ``(mean, std)``
    tensors so that it can be passed to ``torch.jit.trace``. The traced
    module is then evaluated and the shapes of both returned tensors are
    verified.
    """
    tkwargs = {"device": self.device, "dtype": torch.double}
    train_X = torch.rand(5, 1, **tkwargs)
    # Two target columns (sin and cos of the inputs).
    train_Y = torch.cat([torch.sin(train_X), torch.cos(train_X)], dim=-1)
    model = SimpleBatchedMultiOutputGPyTorchModel(train_X, train_Y)

    class MeanVarModelWrapper(torch.nn.Module):
        """Module returning the detached posterior mean and standard deviation."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x):
            # get the model posterior
            posterior = self.model.posterior(x, observation_noise=True)
            mean = posterior.mean.detach()
            std = posterior.variance.sqrt().detach()
            return mean, std

    wrapped_model = MeanVarModelWrapper(model)
    # 3 test points x 2 outputs.
    expected_shape = torch.Size([3, 2])
    with torch.no_grad(), trace_mode():
        X_test = torch.rand(3, 1, **tkwargs)
        wrapped_model(X_test)  # Compute caches
        traced_model = torch.jit.trace(wrapped_model, X_test)
        mean, std = traced_model(X_test)
        self.assertEqual(mean.shape, expected_shape)
        # The original test returned ``std`` but never checked it; verify the
        # variance path of the traced posterior as well.
        self.assertEqual(std.shape, expected_shape)


class TestModelListGPyTorchModel(BotorchTestCase):
def test_model_list_gpytorch_model(self):
Expand Down

0 comments on commit e7539db

Please sign in to comment.