Skip to content

Commit

Permalink
Update the default SingleTaskGP prior (#2449)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebook/Ax#2610

Pull Request resolved: #2449

Update of the default hyperparameter priors for the SingleTaskGP.

Switch from the conventional Scale-Matern kernel with Gamma(3, 6) lengthscale prior is substituted for an RBF Kernel (without a ScaleKernel), and a change from the high-noise Gamma(1.1, 0.05) noise prior of the GaussianLikelihood to a LogNormal prior that prefers lower values. The change is made in accordance with the findings of [1].

 The change is made to improve the out-of-the-box performance of the BoTorch models on high-dimensional problems.

[1] Carl Hvarfner, Erik Orm Hellsten, Luigi Nardi. _Vanilla Bayesian Optimization Performs Great in High Dimensions_. ICML, 2024.

Reviewed By: saitcakmak

Differential Revision: D60080819
  • Loading branch information
Carl Hvarfner authored and facebook-github-bot committed Jul 30, 2024
1 parent 4497a5c commit f7732ac
Show file tree
Hide file tree
Showing 17 changed files with 163 additions and 93 deletions.
21 changes: 12 additions & 9 deletions botorch/models/gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
from botorch.models.transforms.outcome import Log, OutcomeTransform
from botorch.models.utils import validate_input_scaling
from botorch.models.utils.gpytorch_modules import (
get_gaussian_likelihood_with_gamma_prior,
get_matern_kernel_with_gamma_prior,
get_covar_module_with_dim_scaled_prior,
get_gaussian_likelihood_with_lognormal_prior,
MIN_INFERRED_NOISE_LEVEL,
)
from botorch.utils.containers import BotorchContainer
Expand All @@ -67,9 +67,13 @@
class SingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin):
r"""A single-task exact GP model, supporting both known and inferred noise levels.
A single-task exact GP using relatively strong priors on the Kernel
hyperparameters, which work best when covariates are normalized to the unit
cube and outcomes are standardized (zero mean, unit variance).
A single-task exact GP which, by default, utilizes hyperparameter priors
from [Hvarfner2024vanilla]_. These priors are fairly weak and designed to be
agnostic to the dimensionality of the problem. Moreover, they suggest a
moderately low level of noise. Importantly, The model works best when covariates
are normalized to the unit cube and outcomes are standardized (zero mean, unit
variance). For a detailed discussion on the hyperparameter priors, see
https://github.com/pytorch/botorch/discussions/2451.
This model works in batch mode (each batch having its own hyperparameters).
When the training observations include multiple outputs, this model will use
Expand Down Expand Up @@ -174,7 +178,7 @@ def __init__(
)
if likelihood is None:
if train_Yvar is None:
likelihood = get_gaussian_likelihood_with_gamma_prior(
likelihood = get_gaussian_likelihood_with_lognormal_prior(
batch_shape=self._aug_batch_shape
)
else:
Expand All @@ -190,14 +194,13 @@ def __init__(
mean_module = ConstantMean(batch_shape=self._aug_batch_shape)
self.mean_module = mean_module
if covar_module is None:
covar_module = get_matern_kernel_with_gamma_prior(
covar_module = get_covar_module_with_dim_scaled_prior(
ard_num_dims=transformed_X.shape[-1],
batch_shape=self._aug_batch_shape,
)
self._subset_batch_dict = {
"mean_module.raw_constant": -1,
"covar_module.raw_outputscale": -1,
"covar_module.base_kernel.raw_lengthscale": -3,
"covar_module.raw_lengthscale": -3,
}
if train_Yvar is None:
self._subset_batch_dict["likelihood.noise_covar.raw_noise"] = -2
Expand Down
7 changes: 3 additions & 4 deletions botorch/utils/gp_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,9 @@ def __init__(
"""
if not isinstance(kernel, ScaleKernel):
base_kernel = kernel
outputscale = torch.tensor(
1.0,
dtype=base_kernel.lengthscale.dtype,
device=base_kernel.lengthscale.device,
outputscale = torch.ones(kernel.batch_shape).to(
dtype=kernel.lengthscale.dtype,
device=kernel.lengthscale.device,
)
else:
base_kernel = kernel.base_kernel
Expand Down
7 changes: 4 additions & 3 deletions botorch_community/models/gp_regression_multisource.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from __future__ import annotations

from copy import deepcopy
from typing import Optional

import torch
Expand Down Expand Up @@ -231,9 +232,9 @@ def _init_fit_gp(
train_X,
train_Y,
train_Yvar,
likelihood=likelihood,
covar_module=covar_module,
mean_module=mean_module,
likelihood=deepcopy(likelihood),
covar_module=deepcopy(covar_module),
mean_module=deepcopy(mean_module),
outcome_transform=outcome_transform,
input_transform=input_transform,
)
Expand Down
44 changes: 34 additions & 10 deletions test/models/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
)
from botorch.models.transforms.input import AppendFeatures, Normalize
from botorch.models.transforms.outcome import Standardize
from botorch.models.utils.gpytorch_modules import get_matern_kernel_with_gamma_prior
from botorch.utils.test_helpers import SimpleGPyTorchModel
from botorch.utils.testing import BotorchTestCase
from gpytorch.kernels import RBFKernel
from gpytorch.kernels import MaternKernel, RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.priors import LogNormalPrior
Expand Down Expand Up @@ -133,14 +134,19 @@ def test_model_list_to_batched(self):
with self.assertRaises(UnsupportedError):
model_list_to_batched(ModelListGP(gp1, gp2))
# check scalar agreement
gp2 = SingleTaskGP(train_X, train_Y2)
gp2.likelihood.noise_covar.noise_prior.rate.fill_(1.0)
# modified to check the scalar agreement in a parameter that is accessible
# since the error is going to slip through for the non-parametrizable
# priors regardless (like the LogNormal)
with self.assertRaises(UnsupportedError):
model_list_to_batched(ModelListGP(gp1, gp2))

gp2.likelihood.noise_covar.raw_noise_constraint.lower_bound.fill_(1e-3)
with self.assertRaises(UnsupportedError):
model_list_to_batched(ModelListGP(gp1, gp2))
# check tensor shape agreement
gp2 = SingleTaskGP(train_X, train_Y2)
gp2.covar_module.raw_outputscale = torch.nn.Parameter(
torch.tensor([0.0], device=self.device, dtype=dtype)
gp2.likelihood.noise_covar.raw_noise = torch.nn.Parameter(
torch.tensor([[0.42]], device=self.device, dtype=dtype)
)
with self.assertRaises(UnsupportedError):
model_list_to_batched(ModelListGP(gp1, gp2))
Expand All @@ -155,14 +161,15 @@ def test_model_list_to_batched(self):
with self.assertRaises(NotImplementedError):
model_list_to_batched(ModelListGP(gp2))
# test non-default kernel
gp1 = SingleTaskGP(train_X, train_Y1, covar_module=RBFKernel())
gp2 = SingleTaskGP(train_X, train_Y2, covar_module=RBFKernel())
gp1 = SingleTaskGP(train_X, train_Y1, covar_module=MaternKernel())
gp2 = SingleTaskGP(train_X, train_Y2, covar_module=MaternKernel())
list_gp = ModelListGP(gp1, gp2)
batch_gp = model_list_to_batched(list_gp)
self.assertEqual(type(batch_gp.covar_module), RBFKernel)
self.assertEqual(type(batch_gp.covar_module), MaternKernel)
# test error when component GPs have different kernel types
gp1 = SingleTaskGP(train_X, train_Y1, covar_module=RBFKernel())
gp2 = SingleTaskGP(train_X, train_Y2)
# added types for both default and non-default kernels for clarity
gp1 = SingleTaskGP(train_X, train_Y1, covar_module=MaternKernel())
gp2 = SingleTaskGP(train_X, train_Y2, covar_module=RBFKernel())
list_gp = ModelListGP(gp1, gp2)
with self.assertRaises(UnsupportedError):
model_list_to_batched(list_gp)
Expand Down Expand Up @@ -244,6 +251,23 @@ def test_model_list_to_batched(self):
with self.assertRaises(UnsupportedError):
model_list_to_batched(list_gp)

def test_model_list_to_batched_with_legacy_prior(self) -> None:
train_X = torch.rand(10, 2, device=self.device, dtype=torch.double)
# and test the old prior for completeness & test coverage
gp1_gamma = SingleTaskGP(
train_X,
train_X.sum(dim=-1, keepdim=True),
covar_module=get_matern_kernel_with_gamma_prior(train_X.shape[-1]),
)
gp2_gamma = SingleTaskGP(
train_X,
train_X.sum(dim=-1, keepdim=True),
covar_module=get_matern_kernel_with_gamma_prior(train_X.shape[-1]),
)
gp1_gamma.covar_module.base_kernel.lengthscale_prior.rate.fill_(1.0)
with self.assertRaises(UnsupportedError):
model_list_to_batched(ModelListGP(gp1_gamma, gp2_gamma))

def test_model_list_to_batched_with_different_prior(self) -> None:
# The goal is to test priors that didn't have their parameters
# recorded in the state dict prior to GPyTorch #2551.
Expand Down
2 changes: 1 addition & 1 deletion test/models/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_FixedSingleSampleModel(self):
post = model.posterior(test_X)
original_output = post.mean + post.variance.sqrt() * w
fss_output = fss_model(test_X)
self.assertTrue(torch.equal(original_output, fss_output))
self.assertAllClose(original_output, fss_output)

self.assertTrue(hasattr(fss_model, "num_outputs"))

Expand Down
12 changes: 6 additions & 6 deletions test/models/test_gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from botorch.utils.sampling import manual_seed
from botorch.utils.test_helpers import get_pvar_expected
from botorch.utils.testing import _get_random_data, BotorchTestCase
from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import (
_GaussianLikelihoodBase,
FixedNoiseGaussianLikelihood,
Expand All @@ -33,7 +33,7 @@
from gpytorch.means import ConstantMean, ZeroMean
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.noise_model_added_loss_term import NoiseModelAddedLossTerm
from gpytorch.priors import GammaPrior
from gpytorch.priors import LogNormalPrior


class TestGPRegressionBase(BotorchTestCase):
Expand Down Expand Up @@ -96,10 +96,10 @@ def test_gp(self, double_only: bool = False):

# test init
self.assertIsInstance(model.mean_module, ConstantMean)
self.assertIsInstance(model.covar_module, ScaleKernel)
matern_kernel = model.covar_module.base_kernel
self.assertIsInstance(matern_kernel, MaternKernel)
self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior)
self.assertIsInstance(model.covar_module, RBFKernel)
rbf_kernel = model.covar_module
self.assertIsInstance(rbf_kernel, RBFKernel)
self.assertIsInstance(rbf_kernel.lengthscale_prior, LogNormalPrior)
if use_octf:
self.assertIsInstance(model.outcome_transform, Standardize)
if use_intf:
Expand Down
10 changes: 4 additions & 6 deletions test/models/test_model_list_gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from botorch.sampling.normal import IIDNormalSampler
from botorch.utils.testing import _get_random_data, BotorchTestCase
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import LikelihoodList
from gpytorch.likelihoods.gaussian_likelihood import (
FixedNoiseGaussianLikelihood,
Expand All @@ -34,7 +34,7 @@
from gpytorch.means import ConstantMean
from gpytorch.mlls import SumMarginalLogLikelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.priors import GammaPrior
from gpytorch.priors import LogNormalPrior
from torch import Tensor


Expand Down Expand Up @@ -104,10 +104,8 @@ def _base_test_ModelListGP(
self.assertEqual(model.num_outputs, 2)
for m in model.models:
self.assertIsInstance(m.mean_module, ConstantMean)
self.assertIsInstance(m.covar_module, ScaleKernel)
matern_kernel = m.covar_module.base_kernel
self.assertIsInstance(matern_kernel, MaternKernel)
self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior)
self.assertIsInstance(m.covar_module, RBFKernel)
self.assertIsInstance(m.covar_module.lengthscale_prior, LogNormalPrior)
if outcome_transform != "None":
self.assertIsInstance(
m.outcome_transform, (Log, Standardize, ChainedOutcomeTransform)
Expand Down
6 changes: 3 additions & 3 deletions test/optim/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def setUp(self) -> None:
self.mlls = {}
with torch.random.fork_rng():
torch.manual_seed(0)
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
train_Y = torch.sin((2 * math.pi) * train_X)
train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
train_X = torch.linspace(0, 1, 30).unsqueeze(-1)
train_Y = torch.sin((6 * math.pi) * train_X)
train_Y = train_Y + 0.01 * torch.randn_like(train_Y)

model = SingleTaskGP(
train_X=train_X,
Expand Down
22 changes: 19 additions & 3 deletions test/optim/utils/test_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

import itertools
import re
import warnings
from copy import deepcopy
Expand All @@ -16,6 +17,10 @@
import torch
from botorch import settings
from botorch.models import SingleTaskGP
from botorch.models.utils.gpytorch_modules import (
get_covar_module_with_dim_scaled_prior,
get_matern_kernel_with_gamma_prior,
)
from botorch.optim.utils import (
get_data_loader,
get_name_filter,
Expand Down Expand Up @@ -158,10 +163,18 @@ def test__get_name_filter(self) -> None:

class TestSampleAllPriors(BotorchTestCase):
def test_sample_all_priors(self):
for dtype in (torch.float, torch.double):
for dtype, covar_module in itertools.product(
(torch.float, torch.double),
(
get_covar_module_with_dim_scaled_prior(ard_num_dims=5),
get_matern_kernel_with_gamma_prior(ard_num_dims=5),
),
):
train_X = torch.rand(3, 5, device=self.device, dtype=dtype)
train_Y = torch.rand(3, 1, device=self.device, dtype=dtype)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
model = SingleTaskGP(
train_X=train_X, train_Y=train_Y, covar_module=covar_module
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
mll.to(device=self.device, dtype=dtype)
original_state_dict = dict(deepcopy(mll.model.state_dict()))
Expand All @@ -173,7 +186,10 @@ def test_sample_all_priors(self):
!= original_state_dict["likelihood.noise_covar.raw_noise"]
)
# check that lengthscales are all different
ls = model.covar_module.base_kernel.raw_lengthscale.view(-1).tolist()
if isinstance(model.covar_module, ScaleKernel):
ls = model.covar_module.base_kernel.raw_lengthscale.view(-1).tolist()
else:
ls = model.covar_module.raw_lengthscale.view(-1).tolist()
self.assertTrue(all(ls[0] != ls[i]) for i in range(1, len(ls)))

# change one of the priors to a dummy prior that does not support sampling
Expand Down
5 changes: 2 additions & 3 deletions test/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from botorch.settings import debug
from botorch.utils.context_managers import module_rollback_ctx, TensorCheckpoint
from botorch.utils.testing import BotorchTestCase
from gpytorch.kernels import MaternKernel
from gpytorch.kernels import RBFKernel
from gpytorch.mlls import ExactMarginalLogLikelihood, VariationalELBO
from linear_operator.utils.errors import NotPSDError

Expand Down Expand Up @@ -136,8 +136,7 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
input_transform=Normalize(d=1),
outcome_transform=Standardize(m=output_dim),
)
self.assertIsInstance(model.covar_module.base_kernel, MaternKernel)
model.covar_module.base_kernel.nu = 2.5
self.assertIsInstance(model.covar_module, RBFKernel)

mll = ExactMarginalLogLikelihood(model.likelihood, model)
for dtype in (torch.float32, torch.float64):
Expand Down
11 changes: 10 additions & 1 deletion test_community/acquisition/test_multi_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

import torch
from botorch.exceptions import UnsupportedError
from botorch.models.utils.gpytorch_modules import (
get_gaussian_likelihood_with_gamma_prior,
get_matern_kernel_with_gamma_prior,
)
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior

from botorch_community.acquisition.augmented_multisource import (
AugmentedUpperConfidenceBound,
)
Expand All @@ -21,11 +24,17 @@ def _get_mock_agp(self, batch_shape, dtype):
rep_shape = batch_shape + torch.Size([1, 1])
train_X = train_X.repeat(rep_shape)
train_Y = train_Y.repeat(rep_shape)
covar_module = get_matern_kernel_with_gamma_prior(
ard_num_dims=train_X.shape[-1] - 1,
)
model_kwargs = {
"train_X": train_X,
"train_Y": train_Y,
"covar_module": covar_module,
"likelihood": get_gaussian_likelihood_with_gamma_prior(),
}
model = SingleTaskAugmentedGP(**model_kwargs)

return model

def test_upper_confidence_bound(self):
Expand Down
Loading

0 comments on commit f7732ac

Please sign in to comment.