From a4b6bdcc5a29521e6d6c19f2680a41b4d96312aa Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Tue, 30 Jul 2024 12:06:57 -0700 Subject: [PATCH] Update the default SingleTaskGP prior (#2449) Summary: X-link: https://github.com/facebook/Ax/pull/2610 Pull Request resolved: https://github.com/pytorch/botorch/pull/2449 Updates the default hyperparameter priors for the SingleTaskGP. The conventional Scale-Matern kernel with a Gamma(3, 6) lengthscale prior is replaced by an RBF kernel (without a ScaleKernel), and the high-noise Gamma(1.1, 0.05) noise prior of the GaussianLikelihood is replaced by a LogNormal prior that prefers lower values. The change is made in accordance with the findings of [1]. The change is made to improve the out-of-the-box performance of the BoTorch models on high-dimensional problems. [1] Carl Hvarfner, Erik Orm Hellsten, Luigi Nardi. _Vanilla Bayesian Optimization Performs Great in High Dimensions_. ICML, 2024. Reviewed By: saitcakmak Differential Revision: D60080819 --- botorch/models/gp_regression.py | 21 ++--- botorch/utils/gp_sampling.py | 7 +- .../models/gp_regression_multisource.py | 7 +- test/acquisition/test_analytic.py | 1 + test/models/test_converter.py | 44 ++++++++--- test/models/test_deterministic.py | 2 +- test/models/test_gp_regression.py | 12 +-- test/models/test_model_list_gp_regression.py | 10 +-- test/optim/test_fit.py | 6 +- test/optim/utils/test_model_utils.py | 22 +++++- test/test_fit.py | 5 +- .../acquisition/test_multi_source.py | 11 ++- .../models/test_gp_regression_multisource.py | 24 +++++- tutorials/baxus.ipynb | 2 +- tutorials/constraint_active_search.ipynb | 2 +- .../fit_model_with_torch_optimizer.ipynb | 7 +- tutorials/ibnn_bo.ipynb | 2 +- ...tion_theoretic_acquisition_functions.ipynb | 77 ++++++++++--------- 18 files changed, 167 insertions(+), 95 deletions(-) diff --git a/botorch/models/gp_regression.py b/botorch/models/gp_regression.py index 4abf32e663..733774ed43 100644 --- 
a/botorch/models/gp_regression.py +++ b/botorch/models/gp_regression.py @@ -40,8 +40,8 @@ from botorch.models.transforms.outcome import Log, OutcomeTransform from botorch.models.utils import validate_input_scaling from botorch.models.utils.gpytorch_modules import ( - get_gaussian_likelihood_with_gamma_prior, - get_matern_kernel_with_gamma_prior, + get_covar_module_with_dim_scaled_prior, + get_gaussian_likelihood_with_lognormal_prior, MIN_INFERRED_NOISE_LEVEL, ) from botorch.utils.containers import BotorchContainer @@ -67,9 +67,13 @@ class SingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin): r"""A single-task exact GP model, supporting both known and inferred noise levels. - A single-task exact GP using relatively strong priors on the Kernel - hyperparameters, which work best when covariates are normalized to the unit - cube and outcomes are standardized (zero mean, unit variance). + A single-task exact GP which, by default, utilizes hyperparameter priors + from [Hvarfner2024vanilla]_. These priors are designed to perform well independently + of the dimensionality of the problem. Moreover, they suggest a moderately low level + of noise. Importantly, the model works best when covariates are normalized to the + unit cube and outcomes are standardized (zero mean, unit variance). For a detailed + discussion on the hyperparameter priors, see + https://github.com/pytorch/botorch/discussions/2451. This model works in batch mode (each batch having its own hyperparameters). 
When the training observations include multiple outputs, this model will use @@ -174,7 +178,7 @@ def __init__( ) if likelihood is None: if train_Yvar is None: - likelihood = get_gaussian_likelihood_with_gamma_prior( + likelihood = get_gaussian_likelihood_with_lognormal_prior( batch_shape=self._aug_batch_shape ) else: @@ -190,14 +194,13 @@ def __init__( mean_module = ConstantMean(batch_shape=self._aug_batch_shape) self.mean_module = mean_module if covar_module is None: - covar_module = get_matern_kernel_with_gamma_prior( + covar_module = get_covar_module_with_dim_scaled_prior( ard_num_dims=transformed_X.shape[-1], batch_shape=self._aug_batch_shape, ) self._subset_batch_dict = { "mean_module.raw_constant": -1, - "covar_module.raw_outputscale": -1, - "covar_module.base_kernel.raw_lengthscale": -3, + "covar_module.raw_lengthscale": -3, } if train_Yvar is None: self._subset_batch_dict["likelihood.noise_covar.raw_noise"] = -2 diff --git a/botorch/utils/gp_sampling.py b/botorch/utils/gp_sampling.py index 008be0d2b1..c5ab09b981 100644 --- a/botorch/utils/gp_sampling.py +++ b/botorch/utils/gp_sampling.py @@ -143,10 +143,9 @@ def __init__( """ if not isinstance(kernel, ScaleKernel): base_kernel = kernel - outputscale = torch.tensor( - 1.0, - dtype=base_kernel.lengthscale.dtype, - device=base_kernel.lengthscale.device, + outputscale = torch.ones(kernel.batch_shape).to( + dtype=kernel.lengthscale.dtype, + device=kernel.lengthscale.device, ) else: base_kernel = kernel.base_kernel diff --git a/botorch_community/models/gp_regression_multisource.py b/botorch_community/models/gp_regression_multisource.py index 29aba27f8c..3af1d49cbc 100644 --- a/botorch_community/models/gp_regression_multisource.py +++ b/botorch_community/models/gp_regression_multisource.py @@ -20,6 +20,7 @@ from __future__ import annotations +from copy import deepcopy from typing import Optional import torch @@ -231,9 +232,9 @@ def _init_fit_gp( train_X, train_Y, train_Yvar, - likelihood=likelihood, - 
covar_module=covar_module, - mean_module=mean_module, + likelihood=deepcopy(likelihood), + covar_module=deepcopy(covar_module), + mean_module=deepcopy(mean_module), outcome_transform=outcome_transform, input_transform=input_transform, ) diff --git a/test/acquisition/test_analytic.py b/test/acquisition/test_analytic.py index 1262a7f3f0..0cb94bd3ef 100644 --- a/test/acquisition/test_analytic.py +++ b/test/acquisition/test_analytic.py @@ -964,6 +964,7 @@ def _test_noisy_expected_improvement( X_observed = get_train_inputs(model, transformed=False)[0] nfan = 5 + torch.manual_seed(1) nEI = NoisyExpectedImprovement(model, X_observed, num_fantasies=nfan) LogNEI = LogNoisyExpectedImprovement(model, X_observed, num_fantasies=nfan) # before assigning, check that the attributes exist diff --git a/test/models/test_converter.py b/test/models/test_converter.py index f36d2e6e7e..bcacbaa5f5 100644 --- a/test/models/test_converter.py +++ b/test/models/test_converter.py @@ -21,9 +21,10 @@ ) from botorch.models.transforms.input import AppendFeatures, Normalize from botorch.models.transforms.outcome import Standardize +from botorch.models.utils.gpytorch_modules import get_matern_kernel_with_gamma_prior from botorch.utils.test_helpers import SimpleGPyTorchModel from botorch.utils.testing import BotorchTestCase -from gpytorch.kernels import RBFKernel +from gpytorch.kernels import MaternKernel, RBFKernel from gpytorch.likelihoods import GaussianLikelihood from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from gpytorch.priors import LogNormalPrior @@ -133,14 +134,19 @@ def test_model_list_to_batched(self): with self.assertRaises(UnsupportedError): model_list_to_batched(ModelListGP(gp1, gp2)) # check scalar agreement - gp2 = SingleTaskGP(train_X, train_Y2) - gp2.likelihood.noise_covar.noise_prior.rate.fill_(1.0) + # modified to check the scalar agreement in a parameter that is accessible + # since the error is going to slip through for the non-parametrizable 
+ # priors regardless (like the LogNormal) + with self.assertRaises(UnsupportedError): + model_list_to_batched(ModelListGP(gp1, gp2)) + + gp2.likelihood.noise_covar.raw_noise_constraint.lower_bound.fill_(1e-3) with self.assertRaises(UnsupportedError): model_list_to_batched(ModelListGP(gp1, gp2)) # check tensor shape agreement gp2 = SingleTaskGP(train_X, train_Y2) - gp2.covar_module.raw_outputscale = torch.nn.Parameter( - torch.tensor([0.0], device=self.device, dtype=dtype) + gp2.likelihood.noise_covar.raw_noise = torch.nn.Parameter( + torch.tensor([[0.42]], device=self.device, dtype=dtype) ) with self.assertRaises(UnsupportedError): model_list_to_batched(ModelListGP(gp1, gp2)) @@ -155,14 +161,15 @@ def test_model_list_to_batched(self): with self.assertRaises(NotImplementedError): model_list_to_batched(ModelListGP(gp2)) # test non-default kernel - gp1 = SingleTaskGP(train_X, train_Y1, covar_module=RBFKernel()) - gp2 = SingleTaskGP(train_X, train_Y2, covar_module=RBFKernel()) + gp1 = SingleTaskGP(train_X, train_Y1, covar_module=MaternKernel()) + gp2 = SingleTaskGP(train_X, train_Y2, covar_module=MaternKernel()) list_gp = ModelListGP(gp1, gp2) batch_gp = model_list_to_batched(list_gp) - self.assertEqual(type(batch_gp.covar_module), RBFKernel) + self.assertEqual(type(batch_gp.covar_module), MaternKernel) # test error when component GPs have different kernel types - gp1 = SingleTaskGP(train_X, train_Y1, covar_module=RBFKernel()) - gp2 = SingleTaskGP(train_X, train_Y2) + # added types for both default and non-default kernels for clarity + gp1 = SingleTaskGP(train_X, train_Y1, covar_module=MaternKernel()) + gp2 = SingleTaskGP(train_X, train_Y2, covar_module=RBFKernel()) list_gp = ModelListGP(gp1, gp2) with self.assertRaises(UnsupportedError): model_list_to_batched(list_gp) @@ -244,6 +251,23 @@ def test_model_list_to_batched(self): with self.assertRaises(UnsupportedError): model_list_to_batched(list_gp) + def test_model_list_to_batched_with_legacy_prior(self) -> None: + 
train_X = torch.rand(10, 2, device=self.device, dtype=torch.double) + # and test the old prior for completeness & test coverage + gp1_gamma = SingleTaskGP( + train_X, + train_X.sum(dim=-1, keepdim=True), + covar_module=get_matern_kernel_with_gamma_prior(train_X.shape[-1]), + ) + gp2_gamma = SingleTaskGP( + train_X, + train_X.sum(dim=-1, keepdim=True), + covar_module=get_matern_kernel_with_gamma_prior(train_X.shape[-1]), + ) + gp1_gamma.covar_module.base_kernel.lengthscale_prior.rate.fill_(1.0) + with self.assertRaises(UnsupportedError): + model_list_to_batched(ModelListGP(gp1_gamma, gp2_gamma)) + def test_model_list_to_batched_with_different_prior(self) -> None: # The goal is to test priors that didn't have their parameters # recorded in the state dict prior to GPyTorch #2551. diff --git a/test/models/test_deterministic.py b/test/models/test_deterministic.py index edb7de22b5..794ef08e78 100644 --- a/test/models/test_deterministic.py +++ b/test/models/test_deterministic.py @@ -172,7 +172,7 @@ def test_FixedSingleSampleModel(self): post = model.posterior(test_X) original_output = post.mean + post.variance.sqrt() * w fss_output = fss_model(test_X) - self.assertTrue(torch.equal(original_output, fss_output)) + self.assertAllClose(original_output, fss_output) self.assertTrue(hasattr(fss_model, "num_outputs")) diff --git a/test/models/test_gp_regression.py b/test/models/test_gp_regression.py index 0f6f4056ee..547598a432 100644 --- a/test/models/test_gp_regression.py +++ b/test/models/test_gp_regression.py @@ -23,7 +23,7 @@ from botorch.utils.sampling import manual_seed from botorch.utils.test_helpers import get_pvar_expected from botorch.utils.testing import _get_random_data, BotorchTestCase -from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel +from gpytorch.kernels import RBFKernel from gpytorch.likelihoods import ( _GaussianLikelihoodBase, FixedNoiseGaussianLikelihood, @@ -33,7 +33,7 @@ from gpytorch.means import ConstantMean, ZeroMean from 
gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from gpytorch.mlls.noise_model_added_loss_term import NoiseModelAddedLossTerm -from gpytorch.priors import GammaPrior +from gpytorch.priors import LogNormalPrior class TestGPRegressionBase(BotorchTestCase): @@ -96,10 +96,10 @@ def test_gp(self, double_only: bool = False): # test init self.assertIsInstance(model.mean_module, ConstantMean) - self.assertIsInstance(model.covar_module, ScaleKernel) - matern_kernel = model.covar_module.base_kernel - self.assertIsInstance(matern_kernel, MaternKernel) - self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior) + self.assertIsInstance(model.covar_module, RBFKernel) + rbf_kernel = model.covar_module + self.assertIsInstance(rbf_kernel, RBFKernel) + self.assertIsInstance(rbf_kernel.lengthscale_prior, LogNormalPrior) if use_octf: self.assertIsInstance(model.outcome_transform, Standardize) if use_intf: diff --git a/test/models/test_model_list_gp_regression.py b/test/models/test_model_list_gp_regression.py index ee1e19dad7..148f8dcf43 100644 --- a/test/models/test_model_list_gp_regression.py +++ b/test/models/test_model_list_gp_regression.py @@ -25,7 +25,7 @@ from botorch.sampling.normal import IIDNormalSampler from botorch.utils.testing import _get_random_data, BotorchTestCase from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal -from gpytorch.kernels import MaternKernel, ScaleKernel +from gpytorch.kernels import RBFKernel from gpytorch.likelihoods import LikelihoodList from gpytorch.likelihoods.gaussian_likelihood import ( FixedNoiseGaussianLikelihood, @@ -34,7 +34,7 @@ from gpytorch.means import ConstantMean from gpytorch.mlls import SumMarginalLogLikelihood from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood -from gpytorch.priors import GammaPrior +from gpytorch.priors import LogNormalPrior from torch import Tensor @@ -104,10 +104,8 @@ def _base_test_ModelListGP( 
self.assertEqual(model.num_outputs, 2) for m in model.models: self.assertIsInstance(m.mean_module, ConstantMean) - self.assertIsInstance(m.covar_module, ScaleKernel) - matern_kernel = m.covar_module.base_kernel - self.assertIsInstance(matern_kernel, MaternKernel) - self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior) + self.assertIsInstance(m.covar_module, RBFKernel) + self.assertIsInstance(m.covar_module.lengthscale_prior, LogNormalPrior) if outcome_transform != "None": self.assertIsInstance( m.outcome_transform, (Log, Standardize, ChainedOutcomeTransform) diff --git a/test/optim/test_fit.py b/test/optim/test_fit.py index 69775f3385..55bcf33908 100644 --- a/test/optim/test_fit.py +++ b/test/optim/test_fit.py @@ -29,9 +29,9 @@ def setUp(self) -> None: self.mlls = {} with torch.random.fork_rng(): torch.manual_seed(0) - train_X = torch.linspace(0, 1, 10).unsqueeze(-1) - train_Y = torch.sin((2 * math.pi) * train_X) - train_Y = train_Y + 0.1 * torch.randn_like(train_Y) + train_X = torch.linspace(0, 1, 30).unsqueeze(-1) + train_Y = torch.sin((6 * math.pi) * train_X) + train_Y = train_Y + 0.01 * torch.randn_like(train_Y) model = SingleTaskGP( train_X=train_X, diff --git a/test/optim/utils/test_model_utils.py b/test/optim/utils/test_model_utils.py index d1a0bb132c..a0ab24c222 100644 --- a/test/optim/utils/test_model_utils.py +++ b/test/optim/utils/test_model_utils.py @@ -6,6 +6,7 @@ from __future__ import annotations +import itertools import re import warnings from copy import deepcopy @@ -16,6 +17,10 @@ import torch from botorch import settings from botorch.models import SingleTaskGP +from botorch.models.utils.gpytorch_modules import ( + get_covar_module_with_dim_scaled_prior, + get_matern_kernel_with_gamma_prior, +) from botorch.optim.utils import ( get_data_loader, get_name_filter, @@ -158,10 +163,18 @@ def test__get_name_filter(self) -> None: class TestSampleAllPriors(BotorchTestCase): def test_sample_all_priors(self): - for dtype in (torch.float, 
torch.double): + for dtype, covar_module in itertools.product( + (torch.float, torch.double), + ( + get_covar_module_with_dim_scaled_prior(ard_num_dims=5), + get_matern_kernel_with_gamma_prior(ard_num_dims=5), + ), + ): train_X = torch.rand(3, 5, device=self.device, dtype=dtype) train_Y = torch.rand(3, 1, device=self.device, dtype=dtype) - model = SingleTaskGP(train_X=train_X, train_Y=train_Y) + model = SingleTaskGP( + train_X=train_X, train_Y=train_Y, covar_module=covar_module + ) mll = ExactMarginalLogLikelihood(model.likelihood, model) mll.to(device=self.device, dtype=dtype) original_state_dict = dict(deepcopy(mll.model.state_dict())) @@ -173,7 +186,10 @@ def test_sample_all_priors(self): != original_state_dict["likelihood.noise_covar.raw_noise"] ) # check that lengthscales are all different - ls = model.covar_module.base_kernel.raw_lengthscale.view(-1).tolist() + if isinstance(model.covar_module, ScaleKernel): + ls = model.covar_module.base_kernel.raw_lengthscale.view(-1).tolist() + else: + ls = model.covar_module.raw_lengthscale.view(-1).tolist() self.assertTrue(all(ls[0] != ls[i]) for i in range(1, len(ls))) # change one of the priors to a dummy prior that does not support sampling diff --git a/test/test_fit.py b/test/test_fit.py index 9e9ac6d644..4541abe315 100644 --- a/test/test_fit.py +++ b/test/test_fit.py @@ -27,7 +27,7 @@ from botorch.settings import debug from botorch.utils.context_managers import module_rollback_ctx, TensorCheckpoint from botorch.utils.testing import BotorchTestCase -from gpytorch.kernels import MaternKernel +from gpytorch.kernels import RBFKernel from gpytorch.mlls import ExactMarginalLogLikelihood, VariationalELBO from linear_operator.utils.errors import NotPSDError @@ -136,8 +136,7 @@ def setUp(self, suppress_input_warnings: bool = True) -> None: input_transform=Normalize(d=1), outcome_transform=Standardize(m=output_dim), ) - self.assertIsInstance(model.covar_module.base_kernel, MaternKernel) - model.covar_module.base_kernel.nu = 
2.5 + self.assertIsInstance(model.covar_module, RBFKernel) mll = ExactMarginalLogLikelihood(model.likelihood, model) for dtype in (torch.float32, torch.float64): diff --git a/test_community/acquisition/test_multi_source.py b/test_community/acquisition/test_multi_source.py index 864c9a898d..cff9f6710a 100644 --- a/test_community/acquisition/test_multi_source.py +++ b/test_community/acquisition/test_multi_source.py @@ -6,8 +6,11 @@ import torch from botorch.exceptions import UnsupportedError +from botorch.models.utils.gpytorch_modules import ( + get_gaussian_likelihood_with_gamma_prior, + get_matern_kernel_with_gamma_prior, +) from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior - from botorch_community.acquisition.augmented_multisource import ( AugmentedUpperConfidenceBound, ) @@ -21,11 +24,17 @@ def _get_mock_agp(self, batch_shape, dtype): rep_shape = batch_shape + torch.Size([1, 1]) train_X = train_X.repeat(rep_shape) train_Y = train_Y.repeat(rep_shape) + covar_module = get_matern_kernel_with_gamma_prior( + ard_num_dims=train_X.shape[-1] - 1, + ) model_kwargs = { "train_X": train_X, "train_Y": train_Y, + "covar_module": covar_module, + "likelihood": get_gaussian_likelihood_with_gamma_prior(), } model = SingleTaskAugmentedGP(**model_kwargs) + return model def test_upper_confidence_bound(self): diff --git a/test_community/models/test_gp_regression_multisource.py b/test_community/models/test_gp_regression_multisource.py index 3e08358e1e..93282efe49 100644 --- a/test_community/models/test_gp_regression_multisource.py +++ b/test_community/models/test_gp_regression_multisource.py @@ -14,6 +14,10 @@ from botorch.exceptions import InputDataError, OptimizationWarning from botorch.models import SingleTaskGP from botorch.models.transforms import Normalize, Standardize +from botorch.models.utils.gpytorch_modules import ( + get_gaussian_likelihood_with_gamma_prior, + get_matern_kernel_with_gamma_prior, +) from botorch.posteriors import GPyTorchPosterior 
from botorch.sampling import SobolQMCNormalSampler from botorch.utils.test_helpers import get_pvar_expected @@ -65,6 +69,12 @@ def _get_model_and_data( "train_Yvar": torch.full_like(train_Y, 0.01) if train_Yvar else None, "outcome_transform": outcome_transform, "input_transform": input_transform, + "covar_module": get_matern_kernel_with_gamma_prior( + ard_num_dims=train_X.shape[-1] - 1 + ), + "likelihood": ( + None if train_Yvar else get_gaussian_likelihood_with_gamma_prior() + ), } model = SingleTaskAugmentedGP(**model_kwargs, **extra_model_kwargs) return model, model_kwargs @@ -109,8 +119,18 @@ def test_get_reliable_observation(self): true_y = torch.sin(x).reshape(-1, 1) y = torch.cos(x).reshape(-1, 1) - model0 = SingleTaskGP(x, true_y) - model1 = SingleTaskGP(x, y) + model0 = SingleTaskGP( + x, + true_y, + covar_module=get_matern_kernel_with_gamma_prior(x.shape[-1]), + likelihood=get_gaussian_likelihood_with_gamma_prior(), + ) + model1 = SingleTaskGP( + x, + y, + covar_module=get_matern_kernel_with_gamma_prior(x.shape[-1]), + likelihood=get_gaussian_likelihood_with_gamma_prior(), + ) res = _get_reliable_observations(model0, model1, x) true_res = torch.cat([torch.arange(0, 5, 1), torch.arange(9, 15, 1)]).int() diff --git a/tutorials/baxus.ipynb b/tutorials/baxus.ipynb index d9fb0a3d69..69ac7af62e 100644 --- a/tutorials/baxus.ipynb +++ b/tutorials/baxus.ipynb @@ -486,7 +486,7 @@ "\n", " # Scale the TR to be proportional to the lengthscales\n", " x_center = X[Y.argmax(), :].clone()\n", - " weights = model.covar_module.base_kernel.lengthscale.detach().view(-1)\n", + " weights = model.covar_module.lengthscale.detach().view(-1)\n", " weights = weights / weights.mean()\n", " weights = weights / torch.prod(weights.pow(1.0 / len(weights)))\n", " tr_lb = torch.clamp(x_center - weights * state.length, -1.0, 1.0)\n", diff --git a/tutorials/constraint_active_search.ipynb b/tutorials/constraint_active_search.ipynb index 8c36a56c7e..e2ca3b90a5 100644 --- 
a/tutorials/constraint_active_search.ipynb +++ b/tutorials/constraint_active_search.ipynb @@ -249,7 +249,7 @@ " return radius * r * z\n", "\n", " def _get_base_point_mask(self, X):\n", - " distance_matrix = self.model.models[0].covar_module.base_kernel.covar_dist(\n", + " distance_matrix = self.model.models[0].covar_module.covar_dist(\n", " X, self.base_points\n", " )\n", " return smooth_mask(distance_matrix, self.punchout_radius)\n", diff --git a/tutorials/fit_model_with_torch_optimizer.ipynb b/tutorials/fit_model_with_torch_optimizer.ipynb index 4d46338183..e3b1682516 100644 --- a/tutorials/fit_model_with_torch_optimizer.ipynb +++ b/tutorials/fit_model_with_torch_optimizer.ipynb @@ -20,6 +20,7 @@ "outputs": [], "source": [ "import math\n", + "\n", "import torch\n", "\n", "# use a GPU if available\n", @@ -134,7 +135,7 @@ "source": [ "from torch.optim import SGD\n", "\n", - "optimizer = SGD([{\"params\": model.parameters()}], lr=0.1)" + "optimizer = SGD([{\"params\": model.parameters()}], lr=0.025)" ] }, { @@ -190,7 +191,7 @@ " if (epoch + 1) % 10 == 0:\n", " print(\n", " f\"Epoch {epoch+1:>3}/{NUM_EPOCHS} - Loss: {loss.item():>4.3f} \"\n", - " f\"lengthscale: {model.covar_module.base_kernel.lengthscale.item():>4.3f} \"\n", + " f\"lengthscale: {model.covar_module.lengthscale.item():>4.3f} \"\n", " f\"noise: {model.likelihood.noise.item():>4.3f}\"\n", " )\n", " optimizer.step()" @@ -215,7 +216,7 @@ "outputs": [], "source": [ "# set model (and likelihood)\n", - "model.eval();" + "model.eval()" ] }, { diff --git a/tutorials/ibnn_bo.ipynb b/tutorials/ibnn_bo.ipynb index 9b17b40936..371b34245c 100644 --- a/tutorials/ibnn_bo.ipynb +++ b/tutorials/ibnn_bo.ipynb @@ -213,7 +213,7 @@ "\n", "plot_posterior(axs[2], model_matern)\n", "axs[2].set_title(\"GP (Matern Kernel)\\nLength Scale: %.2f\" % \n", - " model_matern.covar_module.base_kernel.lengthscale.item(), \n", + " model_matern.covar_module.lengthscale.item(), \n", " fontsize=20)\n", "axs[2].set_ylim(-7, 8)\n", "\n", diff 
--git a/tutorials/information_theoretic_acquisition_functions.ipynb b/tutorials/information_theoretic_acquisition_functions.ipynb index 4c7e4f84a7..61dcebab67 100644 --- a/tutorials/information_theoretic_acquisition_functions.ipynb +++ b/tutorials/information_theoretic_acquisition_functions.ipynb @@ -260,13 +260,13 @@ "import os\n", "\n", "import matplotlib.pyplot as plt\n", - "import torch\n", "import numpy as np\n", - "from botorch.utils.sampling import draw_sobol_samples\n", - "from botorch.models.transforms.outcome import Standardize\n", + "import torch\n", + "from botorch.fit import fit_gpytorch_mll\n", "from botorch.models.gp_regression import SingleTaskGP\n", + "from botorch.models.transforms.outcome import Standardize\n", + "from botorch.utils.sampling import draw_sobol_samples\n", "from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood\n", - "from botorch.fit import fit_gpytorch_mll\n", "\n", "SMOKE_TEST = os.environ.get(\"SMOKE_TEST\")\n", "tkwargs = {\"dtype\": torch.double, \"device\": \"cpu\"}\n", @@ -369,12 +369,10 @@ "source": [ "from botorch.acquisition.utils import get_optimal_samples\n", "\n", - "num_samples = 32\n", + "num_samples = 12\n", "\n", "optimal_inputs, optimal_outputs = get_optimal_samples(\n", - " model,\n", - " bounds=bounds,\n", - " num_optima=num_samples\n", + " model, bounds=bounds, num_optima=num_samples\n", ")" ] }, @@ -397,17 +395,15 @@ }, "outputs": [], "source": [ - "from botorch.acquisition.predictive_entropy_search import qPredictiveEntropySearch\n", - "from botorch.acquisition.max_value_entropy_search import (\n", - " qLowerBoundMaxValueEntropy,\n", - ")\n", "from botorch.acquisition.joint_entropy_search import qJointEntropySearch\n", + "from botorch.acquisition.max_value_entropy_search import qLowerBoundMaxValueEntropy\n", + "from botorch.acquisition.predictive_entropy_search import qPredictiveEntropySearch\n", "\n", "pes = qPredictiveEntropySearch(model=model, optimal_inputs=optimal_inputs)\n", 
"\n", "# Here we use the lower bound estimates for the MES and JES\n", "# Note that the single-objective MES interface is slightly different,\n", - "# as it utilizes the Gumbel max-value approximation internally and \n", + "# as it utilizes the Gumbel max-value approximation internally and\n", "# therefore does not take the max values as input.\n", "mes_lb = qLowerBoundMaxValueEntropy(\n", " model=model,\n", @@ -452,14 +448,18 @@ " pes_X = pes_X / pes_X.max()\n", " mes_lb_X = mes_lb_X / mes_lb_X.max()\n", " jes_lb_X = jes_lb_X / jes_lb_X.max()\n", - " \n", + "\n", "plt.plot(X, pes_X, color=\"mediumseagreen\", linewidth=3, label=\"PES\")\n", "plt.plot(X, mes_lb_X, color=\"crimson\", linewidth=3, label=\"MES-LB\")\n", "plt.plot(X, jes_lb_X, color=\"dodgerblue\", linewidth=3, label=\"JES-LB\")\n", "\n", - "plt.vlines(X[pes_X.argmax()], 0, 1, color=\"mediumseagreen\", linewidth=1.5, linestyle='--')\n", - "plt.vlines(X[mes_lb_X.argmax()], 0, 1, color=\"crimson\", linewidth=1.5, linestyle=':')\n", - "plt.vlines(X[jes_lb_X.argmax()], 0, 1, color=\"dodgerblue\", linewidth=1.5, linestyle='--')\n", + "plt.vlines(\n", + " X[pes_X.argmax()], 0, 1, color=\"mediumseagreen\", linewidth=1.5, linestyle=\"--\"\n", + ")\n", + "plt.vlines(X[mes_lb_X.argmax()], 0, 1, color=\"crimson\", linewidth=1.5, linestyle=\":\")\n", + "plt.vlines(\n", + " X[jes_lb_X.argmax()], 0, 1, color=\"dodgerblue\", linewidth=1.5, linestyle=\"--\"\n", + ")\n", "plt.legend(fontsize=15)\n", "plt.xlabel(\"$x$\", fontsize=15)\n", "plt.ylabel(r\"$\\alpha(x)$\", fontsize=15)\n", @@ -491,8 +491,8 @@ " acq_function=pes,\n", " bounds=bounds,\n", " q=1,\n", - " num_restarts=10,\n", - " raw_samples=512,\n", + " num_restarts=4,\n", + " raw_samples=256,\n", " options={\"with_grad\": False},\n", ")\n", "print(\"PES: candidate={}, acq_value={}\".format(candidate, acq_value))\n", @@ -501,8 +501,8 @@ " acq_function=mes_lb,\n", " bounds=bounds,\n", " q=1,\n", - " num_restarts=10,\n", - " raw_samples=512,\n", + " 
num_restarts=4,\n", + " raw_samples=256,\n", ")\n", "print(\"MES-LB: candidate={}, acq_value={}\".format(candidate, acq_value))\n", "\n", @@ -510,8 +510,8 @@ " acq_function=jes_lb,\n", " bounds=bounds,\n", " q=1,\n", - " num_restarts=10,\n", - " raw_samples=512,\n", + " num_restarts=4,\n", + " raw_samples=256,\n", ")\n", "print(\"JES-LB: candidate={}, acq_value={}\".format(candidate, acq_value))" ] @@ -541,18 +541,19 @@ }, "outputs": [], "source": [ - "from botorch.test_functions.multi_objective import ZDT1\n", "from botorch.acquisition.multi_objective.utils import (\n", - " sample_optimal_points,\n", + " compute_sample_box_decomposition,\n", " random_search_optimizer,\n", - " compute_sample_box_decomposition\n", + " sample_optimal_points,\n", ")\n", + "from botorch.test_functions.multi_objective import ZDT1\n", + "\n", "d = 4\n", "M = 2\n", - "n = 16\n", + "n = 8\n", "\n", "if SMOKE_TEST:\n", - " q = 3\n", + " q = 2\n", "else:\n", " q = 4" ] @@ -592,12 +593,12 @@ }, "outputs": [], "source": [ - "num_pareto_samples = 10\n", - "num_pareto_points = 10\n", + "num_pareto_samples = 8\n", + "num_pareto_points = 8\n", "\n", "# We set the parameters for the random search\n", "optimizer_kwargs = {\n", - " \"pop_size\": 2000,\n", + " \"pop_size\": 500,\n", " \"max_tries\": 10,\n", "}\n", "\n", @@ -628,14 +629,14 @@ }, "outputs": [], "source": [ - "from botorch.acquisition.multi_objective.predictive_entropy_search import (\n", - " qMultiObjectivePredictiveEntropySearch,\n", + "from botorch.acquisition.multi_objective.joint_entropy_search import (\n", + " qLowerBoundMultiObjectiveJointEntropySearch,\n", ")\n", "from botorch.acquisition.multi_objective.max_value_entropy_search import (\n", " qLowerBoundMultiObjectiveMaxValueEntropySearch,\n", ")\n", - "from botorch.acquisition.multi_objective.joint_entropy_search import (\n", - " qLowerBoundMultiObjectiveJointEntropySearch,\n", + "from botorch.acquisition.multi_objective.predictive_entropy_search import (\n", + " 
qMultiObjectivePredictiveEntropySearch,\n", ")\n", "\n", "pes = qMultiObjectivePredictiveEntropySearch(model=model, pareto_sets=ps)\n", @@ -682,7 +683,7 @@ " acq_function=pes,\n", " bounds=bounds,\n", " q=q,\n", - " num_restarts=5,\n", + " num_restarts=4,\n", " raw_samples=512,\n", " options={\"with_grad\": False},\n", ")\n", @@ -693,7 +694,7 @@ " acq_function=mes_lb,\n", " bounds=bounds,\n", " q=q,\n", - " num_restarts=5,\n", + " num_restarts=4,\n", " raw_samples=512,\n", " sequential=True,\n", ")\n", @@ -714,7 +715,7 @@ " acq_function=jes_lb,\n", " bounds=bounds,\n", " q=q,\n", - " num_restarts=5,\n", + " num_restarts=4,\n", " raw_samples=512,\n", " sequential=True,\n", ")\n",