Skip to content

Commit

Permalink
Remove Fully Bayesian logic in low_rank (#1773)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1773

Remove the fully Bayesian specialty logic (previously necessary to deal with non-standard batch ranges).

Reviewed By: sdaulton

Differential Revision: D44634780

fbshipit-source-id: 191bce4ca9cb67325ae66cea2c5a0ddb159c2f4a
  • Loading branch information
Balandat authored and facebook-github-bot committed Apr 4, 2023
1 parent 1176a38 commit ae7ceb9
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions botorch/utils/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch
from botorch.exceptions.errors import BotorchError
from botorch.posteriors.base_samples import _reshape_base_samples_non_interleaved
from botorch.posteriors.fully_bayesian import FullyBayesianPosterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.distributions.multitask_multivariate_normal import (
MultitaskMultivariateNormal,
Expand Down Expand Up @@ -62,12 +61,9 @@ def _reshape_base_samples(
mvn = posterior.distribution
loc = mvn.loc
peshape = posterior._extended_shape()
is_fully_b = int(isinstance(posterior, FullyBayesianPosterior))
base_samples = base_samples.view(
sample_shape
+ torch.Size([1 for _ in range(loc.ndim - 1 - is_fully_b)])
+ peshape[-2 - is_fully_b :]
).expand(sample_shape + loc.shape[: -1 - is_fully_b] + peshape[-2 - is_fully_b :])
sample_shape + torch.Size([1] * (loc.ndim - 1)) + peshape[-2:]
).expand(sample_shape + loc.shape[:-1] + peshape[-2:])
if posterior._is_mt:
base_samples = _reshape_base_samples_non_interleaved(
mvn=posterior.distribution,
Expand Down

0 comments on commit ae7ceb9

Please sign in to comment.