Skip to content

Commit

Permalink
Patch codecov for SaasFullyBayesianSingleTaskGP & FullyBayesianPosterior
Browse files Browse the repository at this point in the history
Summary: --

Reviewed By: Balandat, SebastianAment

Differential Revision: D54278602

fbshipit-source-id: 259fc73e03a68a13507b1574c50ca21fca5d86b2
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Feb 27, 2024
1 parent bd76804 commit fbb460d
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion test/models/test_fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@
SaasPyroModel,
)
from botorch.models.transforms import Normalize, Standardize
from botorch.posteriors.fully_bayesian import batched_bisect, GaussianMixturePosterior
from botorch.posteriors.fully_bayesian import (
batched_bisect,
FullyBayesianPosterior,
GaussianMixturePosterior,
)
from botorch.sampling.get_sampler import get_sampler
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
Expand Down Expand Up @@ -752,6 +756,15 @@ def test_condition_on_observation(self):
torch.Size([num_models, num_train + num_cond, num_dims]),
)

# With batch size only on Y.
cond_model = model.condition_on_observations(
cond_X_nobatch, cond_Y, noise=cond_Yvar
)
self.assertEqual(
cond_model.train_inputs[0].shape,
torch.Size([num_models, num_train + num_cond, num_dims]),
)

# test repeated conditining
repeat_cond_X = cond_X + 5
repeat_cond_model = cond_model.condition_on_observations(
Expand Down Expand Up @@ -815,6 +828,17 @@ def f(x):
dist.cdf(x), q * torch.ones(1, 5, 1, **tkwargs), atol=1e-4
)

def test_deprecated_posterior(self) -> None:
mean = torch.randn(1, 5)
variance = torch.rand(1, 5)
covar = torch.diag_embed(variance)
mvn = MultivariateNormal(mean, to_linear_operator(covar))
with self.assertWarnsRegex(
DeprecationWarning, "`FullyBayesianPosterior` is marked for deprecation"
):
posterior = FullyBayesianPosterior(distribution=mvn)
self.assertIsInstance(posterior, GaussianMixturePosterior)


class TestPyroCatchNumericalErrors(BotorchTestCase):
def tearDown(self):
Expand Down

0 comments on commit fbb460d

Please sign in to comment.