From 3a1a1621b2e0ab195292b37c2c940102805133ed Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Tue, 30 Jul 2024 10:49:17 -0700 Subject: [PATCH] Update the default SingleTaskGP prior (#2610) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2610 X-link: https://github.com/pytorch/botorch/pull/2449 Update of the default hyperparameter priors for the SingleTaskGP. Switch from the conventional Scale-Matern kernel with Gamma(3, 6) lengthscale prior is substituted for an RBF Kernel (without a ScaleKernel), and a change from the high-noise Gamma(1.1, 0.05) noise prior of the GaussianLikelihood to a LogNormal prior that prefers lower values. The change is made in accordance with the findings of [1]. The change is made to improve the out-of-the-box performance of the BoTorch models on high-dimensional problems. [1] Carl Hvarfner, Erik Orm Hellsten, Luigi Nardi. _Vanilla Bayesian Optimization Performs Great in High Dimensions_. ICML, 2024. Reviewed By: saitcakmak Differential Revision: D60080819 --- ax/models/tests/test_botorch_defaults.py | 5 ++-- ax/models/tests/test_botorch_model.py | 34 +++++++++-------------- ax/models/torch/botorch.py | 10 +++++-- ax/models/torch/tests/test_model.py | 9 +++--- ax/plot/tests/test_feature_importances.py | 5 +++- ax/utils/sensitivity/derivative_gp.py | 16 ++++++++--- 6 files changed, 45 insertions(+), 34 deletions(-) diff --git a/ax/models/tests/test_botorch_defaults.py b/ax/models/tests/test_botorch_defaults.py index 989d7994bea..fcd4f17ef1c 100644 --- a/ax/models/tests/test_botorch_defaults.py +++ b/ax/models/tests/test_botorch_defaults.py @@ -6,6 +6,7 @@ # pyre-strict +import math from copy import deepcopy from unittest import mock from unittest.mock import Mock @@ -66,9 +67,9 @@ def test_get_model(self) -> None: self.assertIsInstance(model, SingleTaskGP) self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood) self.assertEqual( - model.covar_module.base_kernel.lengthscale_prior.concentration, 3.0 + model.covar_module.lengthscale_prior.loc, math.log(2.0) / 2 + 2**0.5 ) - self.assertEqual(model.covar_module.base_kernel.lengthscale_prior.rate, 6.0) + self.assertEqual(model.covar_module.lengthscale_prior.scale, 3**0.5) model = _get_model(X=x, Y=y, Yvar=unknown_var, task_feature=1) self.assertIs(type(model), MultiTaskGP) # Don't accept subclasses. self.assertIsInstance(model.likelihood, GaussianLikelihood) diff --git a/ax/models/tests/test_botorch_model.py b/ax/models/tests/test_botorch_model.py index 9e42efd6d14..0a18f5ff149 100644 --- a/ax/models/tests/test_botorch_model.py +++ b/ax/models/tests/test_botorch_model.py @@ -36,6 +36,7 @@ from botorch.models.transforms.input import Warp from botorch.utils.datasets import SupervisedDataset from botorch.utils.objective import get_objective_weights_transform +from gpytorch.kernels.constant_kernel import ConstantKernel from gpytorch.likelihoods import _GaussianLikelihoodBase from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood @@ -558,19 +559,12 @@ def test_BotorchModel( # Test loading state dict true_state_dict = { - "mean_module.raw_constant": 3.5004, - "covar_module.raw_outputscale": 2.2438, - "covar_module.base_kernel.raw_lengthscale": [ - [-0.9274, -0.9274, -0.9274] - ], - "covar_module.base_kernel.raw_lengthscale_constraint.lower_bound": 0.1, - "covar_module.base_kernel.raw_lengthscale_constraint.upper_bound": 2.5, - "covar_module.base_kernel.lengthscale_prior.concentration": 3.0, - "covar_module.base_kernel.lengthscale_prior.rate": 6.0, - "covar_module.raw_outputscale_constraint.lower_bound": 0.2, - "covar_module.raw_outputscale_constraint.upper_bound": 2.6, - "covar_module.outputscale_prior.concentration": 2.0, - "covar_module.outputscale_prior.rate": 0.15, + "mean_module.raw_constant": 1.0, + "covar_module.raw_lengthscale": [[0.3548, 0.3548, 0.3548]], + "covar_module.lengthscale_prior._transformed_loc": 1.9635, + "covar_module.lengthscale_prior._transformed_scale": 1.7321, + "covar_module.raw_lengthscale_constraint.lower_bound": 0.0250, + "covar_module.raw_lengthscale_constraint.upper_bound": float("inf"), } true_state_dict = { key: torch.tensor(val, **tkwargs) @@ -591,8 +585,7 @@ def test_BotorchModel( # Test for some change in model parameters & buffer for refit_model=True true_state_dict["mean_module.raw_constant"] += 0.1 - true_state_dict["covar_module.raw_outputscale"] += 0.1 - true_state_dict["covar_module.base_kernel.raw_lengthscale"] += 0.1 + true_state_dict["covar_module.raw_lengthscale"] += 0.1 model = get_and_fit_model( Xs=Xs1, Ys=Ys1, @@ -774,17 +767,16 @@ def test_get_feature_importances_from_botorch_model(self) -> None: train_X = torch.rand(5, 3, **tkwargs) train_Y = train_X.sum(dim=-1, keepdim=True) simple_gp = SingleTaskGP(train_X=train_X, train_Y=train_Y) - simple_gp.covar_module.base_kernel.lengthscale = torch.tensor( - [1, 3, 5], **tkwargs - ) + simple_gp.covar_module.lengthscale = torch.tensor([1, 3, 5], **tkwargs) importances = get_feature_importances_from_botorch_model(simple_gp) self.assertTrue(np.allclose(importances, np.array([15 / 23, 5 / 23, 3 / 23]))) self.assertEqual(importances.shape, (1, 1, 3)) - # Model with no base kernel - simple_gp.covar_module.base_kernel = None + # Model with kernel that has no lengthscales + simple_gp.covar_module = ConstantKernel() with self.assertRaisesRegex( NotImplementedError, - "Failed to extract lengthscales from `m.covar_module.base_kernel`", + "Failed to extract lengthscales from `m.covar_module` and " + "`m.covar_module.base_kernel`", ): get_feature_importances_from_botorch_model(simple_gp) diff --git a/ax/models/torch/botorch.py b/ax/models/torch/botorch.py index 11329ab8b40..81cf017bc48 100644 --- a/ax/models/torch/botorch.py +++ b/ax/models/torch/botorch.py @@ -562,7 +562,12 @@ def get_feature_importances_from_botorch_model( lengthscales = [] for m in models: try: - ls = m.covar_module.base_kernel.lengthscale + # this can be a ModelList of a SAAS and STGP, so this is a necessary way + # to get the lengthscale + if hasattr(m.covar_module, "base_kernel"): + ls = m.covar_module.base_kernel.lengthscale + else: + ls = m.covar_module.lengthscale except AttributeError: ls = None if ls is None or ls.shape[-1] != m.train_inputs[0].shape[-1]: @@ -570,7 +575,8 @@ def get_feature_importances_from_botorch_model( # case, but this require knowing the batch dimension of this model. # Consider supporting in the future. raise NotImplementedError( - "Failed to extract lengthscales from `m.covar_module.base_kernel`" + "Failed to extract lengthscales from `m.covar_module` " + "and `m.covar_module.base_kernel`" ) if ls.ndim == 2: ls = ls.unsqueeze(0) diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index cc2ab32f16a..9f9273f635b 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -634,8 +634,8 @@ def test_feature_importances(self) -> None: self.assertEqual(importances.shape, (1, 1, 3)) saas_model = deepcopy(model.surrogate.model) else: - model.surrogate.model.covar_module.base_kernel.lengthscale = ( - torch.tensor([1, 2, 3], **self.tkwargs) + model.surrogate.model.covar_module.lengthscale = torch.tensor( + [1, 2, 3], **self.tkwargs ) importances = model.feature_importances() self.assertTrue( @@ -658,11 +658,12 @@ def test_feature_importances(self) -> None: ) self.assertEqual(importances.shape, (2, 1, 3)) # Add model we don't support - vanilla_model.covar_module.base_kernel = None + vanilla_model.covar_module = None model.surrogate._model = vanilla_model # pyre-ignore with self.assertRaisesRegex( NotImplementedError, - "Failed to extract lengthscales from `m.covar_module.base_kernel`", + "Failed to extract lengthscales from `m.covar_module` " + "and `m.covar_module.base_kernel`", ): model.feature_importances() # Test model is None diff --git a/ax/plot/tests/test_feature_importances.py b/ax/plot/tests/test_feature_importances.py index dbb952e92a6..b315d6e9620 100644 --- a/ax/plot/tests/test_feature_importances.py +++ b/ax/plot/tests/test_feature_importances.py @@ -47,7 +47,10 @@ def get_sensitivity_values(ax_model: ModelBridge) -> Dict: Returns map {'metric_name': {'parameter_name': sensitivity_value}} """ - ls = ax_model.model.model.covar_module.base_kernel.lengthscale.squeeze() + if hasattr(ax_model.model.model.covar_module, "outputscale"): + ls = ax_model.model.model.covar_module.base_kernel.lengthscale.squeeze() + else: + ls = ax_model.model.model.covar_module.lengthscale.squeeze() if len(ls.shape) > 1: ls = ls.mean(dim=0) # pyre-fixme[16]: `float` has no attribute `detach`. diff --git a/ax/utils/sensitivity/derivative_gp.py b/ax/utils/sensitivity/derivative_gp.py index 207a9ee7065..4f6a520a5f6 100644 --- a/ax/utils/sensitivity/derivative_gp.py +++ b/ax/utils/sensitivity/derivative_gp.py @@ -37,7 +37,12 @@ def get_KxX_dx(gp: Model, x: Tensor, kernel_type: str = "rbf") -> Tensor: D = X.shape[1] N = X.shape[0] n = x.shape[0] - lengthscale = gp.covar_module.base_kernel.lengthscale.detach() + if hasattr(gp.covar_module, "outputscale"): + lengthscale = gp.covar_module.base_kernel.lengthscale.detach() + sigma_f = gp.covar_module.outputscale.detach() + else: + lengthscale = gp.covar_module.lengthscale.detach() + sigma_f = 1.0 if kernel_type == "rbf": K_xX = gp.covar_module(x, X).evaluate() part1 = -torch.eye(D, device=x.device, dtype=x.dtype) / lengthscale**2 @@ -52,7 +57,6 @@ def get_KxX_dx(gp: Model, x: Tensor, kernel_type: str = "rbf") -> Tensor: constant_component = (-5.0 / 3.0) * distance - (5.0 * math.sqrt(5.0) / 3.0) * ( distance**2 ) - sigma_f = gp.covar_module.outputscale.detach() part1 = torch.eye(D, device=lengthscale.device) / lengthscale part2 = (x1_.view(n, 1, D) - x2_.view(1, N, D)) / distance.unsqueeze(2) total_k = sigma_f * constant_component * exp_component @@ -70,8 +74,12 @@ def get_Kxx_dx2(gp: Model, kernel_type: str = "rbf") -> Tensor: """ X = gp.train_inputs[0] D = X.shape[1] - lengthscale = gp.covar_module.base_kernel.lengthscale.detach() - sigma_f = gp.covar_module.outputscale.detach() + if hasattr(gp.covar_module, "outputscale"): + lengthscale = gp.covar_module.base_kernel.lengthscale.detach() + sigma_f = gp.covar_module.outputscale.detach() + else: + lengthscale = gp.covar_module.lengthscale.detach() + sigma_f = 1.0 res = (torch.eye(D, device=lengthscale.device) / lengthscale**2) * sigma_f if kernel_type == "rbf": return res