Skip to content

Commit

Permalink
Update the default SingleTaskGP prior
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/botorch#2449

See title

Differential Revision: D60080819
  • Loading branch information
Carl Hvarfner authored and facebook-github-bot committed Jul 29, 2024
1 parent af3029f commit 503c128
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 34 deletions.
5 changes: 3 additions & 2 deletions ax/models/tests/test_botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict

import math
from copy import deepcopy
from unittest import mock
from unittest.mock import Mock
Expand Down Expand Up @@ -66,9 +67,9 @@ def test_get_model(self) -> None:
self.assertIsInstance(model, SingleTaskGP)
self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
self.assertEqual(
model.covar_module.base_kernel.lengthscale_prior.concentration, 3.0
model.covar_module.lengthscale_prior.loc, math.log(2.0) / 2 + 2**0.5
)
self.assertEqual(model.covar_module.base_kernel.lengthscale_prior.rate, 6.0)
self.assertEqual(model.covar_module.lengthscale_prior.scale, 3**0.5)
model = _get_model(X=x, Y=y, Yvar=unknown_var, task_feature=1)
self.assertIs(type(model), MultiTaskGP) # Don't accept subclasses.
self.assertIsInstance(model.likelihood, GaussianLikelihood)
Expand Down
34 changes: 13 additions & 21 deletions ax/models/tests/test_botorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from botorch.models.transforms.input import Warp
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.objective import get_objective_weights_transform
from gpytorch.kernels.constant_kernel import ConstantKernel
from gpytorch.likelihoods import _GaussianLikelihoodBase
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood
Expand Down Expand Up @@ -558,19 +559,12 @@ def test_BotorchModel(

# Test loading state dict
true_state_dict = {
"mean_module.raw_constant": 3.5004,
"covar_module.raw_outputscale": 2.2438,
"covar_module.base_kernel.raw_lengthscale": [
[-0.9274, -0.9274, -0.9274]
],
"covar_module.base_kernel.raw_lengthscale_constraint.lower_bound": 0.1,
"covar_module.base_kernel.raw_lengthscale_constraint.upper_bound": 2.5,
"covar_module.base_kernel.lengthscale_prior.concentration": 3.0,
"covar_module.base_kernel.lengthscale_prior.rate": 6.0,
"covar_module.raw_outputscale_constraint.lower_bound": 0.2,
"covar_module.raw_outputscale_constraint.upper_bound": 2.6,
"covar_module.outputscale_prior.concentration": 2.0,
"covar_module.outputscale_prior.rate": 0.15,
"mean_module.raw_constant": 1.0,
"covar_module.raw_lengthscale": [[0.3548, 0.3548, 0.3548]],
"covar_module.lengthscale_prior._transformed_loc": 1.9635,
"covar_module.lengthscale_prior._transformed_scale": 1.7321,
"covar_module.raw_lengthscale_constraint.lower_bound": 0.0250,
"covar_module.raw_lengthscale_constraint.upper_bound": float("inf"),
}
true_state_dict = {
key: torch.tensor(val, **tkwargs)
Expand All @@ -591,8 +585,7 @@ def test_BotorchModel(

# Test for some change in model parameters & buffer for refit_model=True
true_state_dict["mean_module.raw_constant"] += 0.1
true_state_dict["covar_module.raw_outputscale"] += 0.1
true_state_dict["covar_module.base_kernel.raw_lengthscale"] += 0.1
true_state_dict["covar_module.raw_lengthscale"] += 0.1
model = get_and_fit_model(
Xs=Xs1,
Ys=Ys1,
Expand Down Expand Up @@ -774,17 +767,16 @@ def test_get_feature_importances_from_botorch_model(self) -> None:
train_X = torch.rand(5, 3, **tkwargs)
train_Y = train_X.sum(dim=-1, keepdim=True)
simple_gp = SingleTaskGP(train_X=train_X, train_Y=train_Y)
simple_gp.covar_module.base_kernel.lengthscale = torch.tensor(
[1, 3, 5], **tkwargs
)
simple_gp.covar_module.lengthscale = torch.tensor([1, 3, 5], **tkwargs)
importances = get_feature_importances_from_botorch_model(simple_gp)
self.assertTrue(np.allclose(importances, np.array([15 / 23, 5 / 23, 3 / 23])))
self.assertEqual(importances.shape, (1, 1, 3))
# Model with no base kernel
simple_gp.covar_module.base_kernel = None
# Model with kernel that has no lengthscales
simple_gp.covar_module = ConstantKernel()
with self.assertRaisesRegex(
NotImplementedError,
"Failed to extract lengthscales from `m.covar_module.base_kernel`",
"Failed to extract lengthscales from `m.covar_module` and "
"`m.covar_module.base_kernel`",
):
get_feature_importances_from_botorch_model(simple_gp)

Expand Down
10 changes: 8 additions & 2 deletions ax/models/torch/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,15 +562,21 @@ def get_feature_importances_from_botorch_model(
lengthscales = []
for m in models:
try:
ls = m.covar_module.base_kernel.lengthscale
# this can be a ModelList of a SAAS and STGP, so this is a necessary way
# to get the lengthscale
if hasattr(m.covar_module, "base_kernel"):
ls = m.covar_module.base_kernel.lengthscale
else:
ls = m.covar_module.lengthscale
except AttributeError:
ls = None
if ls is None or ls.shape[-1] != m.train_inputs[0].shape[-1]:
# TODO: We could potentially set the feature importances to NaN in this
# case, but this require knowing the batch dimension of this model.
# Consider supporting in the future.
raise NotImplementedError(
"Failed to extract lengthscales from `m.covar_module.base_kernel`"
"Failed to extract lengthscales from `m.covar_module` "
"and `m.covar_module.base_kernel`"
)
if ls.ndim == 2:
ls = ls.unsqueeze(0)
Expand Down
9 changes: 5 additions & 4 deletions ax/models/torch/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,8 +635,8 @@ def test_feature_importances(self) -> None:
self.assertEqual(importances.shape, (1, 1, 3))
saas_model = deepcopy(model.surrogate.model)
else:
model.surrogate.model.covar_module.base_kernel.lengthscale = (
torch.tensor([1, 2, 3], **self.tkwargs)
model.surrogate.model.covar_module.lengthscale = torch.tensor(
[1, 2, 3], **self.tkwargs
)
importances = model.feature_importances()
self.assertTrue(
Expand All @@ -659,11 +659,12 @@ def test_feature_importances(self) -> None:
)
self.assertEqual(importances.shape, (2, 1, 3))
# Add model we don't support
vanilla_model.covar_module.base_kernel = None
vanilla_model.covar_module = None
model.surrogate._model = vanilla_model # pyre-ignore
with self.assertRaisesRegex(
NotImplementedError,
"Failed to extract lengthscales from `m.covar_module.base_kernel`",
"Failed to extract lengthscales from `m.covar_module` "
"and `m.covar_module.base_kernel`",
):
model.feature_importances()
# Test model is None
Expand Down
5 changes: 4 additions & 1 deletion ax/plot/tests/test_feature_importances.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def get_sensitivity_values(ax_model: ModelBridge) -> Dict:
Returns map {'metric_name': {'parameter_name': sensitivity_value}}
"""
ls = ax_model.model.model.covar_module.base_kernel.lengthscale.squeeze()
if hasattr(ax_model.model.model.covar_module, "outputscale"):
ls = ax_model.model.model.covar_module.base_kernel.lengthscale.squeeze()
else:
ls = ax_model.model.model.covar_module.lengthscale.squeeze()
if len(ls.shape) > 1:
ls = ls.mean(dim=0)
# pyre-fixme[16]: `float` has no attribute `detach`.
Expand Down
16 changes: 12 additions & 4 deletions ax/utils/sensitivity/derivative_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ def get_KxX_dx(gp: Model, x: Tensor, kernel_type: str = "rbf") -> Tensor:
D = X.shape[1]
N = X.shape[0]
n = x.shape[0]
lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
if hasattr(gp.covar_module, "outputscale"):
lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
sigma_f = gp.covar_module.outputscale.detach()
else:
lengthscale = gp.covar_module.lengthscale.detach()
sigma_f = 1.0
if kernel_type == "rbf":
K_xX = gp.covar_module(x, X).evaluate()
part1 = -torch.eye(D, device=x.device, dtype=x.dtype) / lengthscale**2
Expand All @@ -52,7 +57,6 @@ def get_KxX_dx(gp: Model, x: Tensor, kernel_type: str = "rbf") -> Tensor:
constant_component = (-5.0 / 3.0) * distance - (5.0 * math.sqrt(5.0) / 3.0) * (
distance**2
)
sigma_f = gp.covar_module.outputscale.detach()
part1 = torch.eye(D, device=lengthscale.device) / lengthscale
part2 = (x1_.view(n, 1, D) - x2_.view(1, N, D)) / distance.unsqueeze(2)
total_k = sigma_f * constant_component * exp_component
Expand All @@ -70,8 +74,12 @@ def get_Kxx_dx2(gp: Model, kernel_type: str = "rbf") -> Tensor:
"""
X = gp.train_inputs[0]
D = X.shape[1]
lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
sigma_f = gp.covar_module.outputscale.detach()
if hasattr(gp.covar_module, "outputscale"):
lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
sigma_f = gp.covar_module.outputscale.detach()
else:
lengthscale = gp.covar_module.lengthscale.detach()
sigma_f = 1.0
res = (torch.eye(D, device=lengthscale.device) / lengthscale**2) * sigma_f
if kernel_type == "rbf":
return res
Expand Down

0 comments on commit 503c128

Please sign in to comment.