From ac14ece95221ce94dd1adf5975adef7d31fe855e Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Thu, 19 Sep 2024 19:59:28 -0700 Subject: [PATCH] Support for priors in OAK Kernel (#2535) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/2535 Added support for registering priors to the coefficients of the OrthogonalAdditiveKernel. Useful for incentivizing sparsity in the additive components, and improve identifiability between first- and second-order components. Differential Revision: D61730632 --- .../kernels/orthogonal_additive_kernel.py | 89 +++++++- .../test_orthogonal_additive_kernel.py | 191 +++++++++++++++++- 2 files changed, 278 insertions(+), 2 deletions(-) diff --git a/botorch/models/kernels/orthogonal_additive_kernel.py b/botorch/models/kernels/orthogonal_additive_kernel.py index 37293c540e..5fce524ef6 100644 --- a/botorch/models/kernels/orthogonal_additive_kernel.py +++ b/botorch/models/kernels/orthogonal_additive_kernel.py @@ -11,10 +11,16 @@ from botorch.exceptions.errors import UnsupportedError from gpytorch.constraints import Interval, Positive from gpytorch.kernels import Kernel +from gpytorch.module import Module +from gpytorch.priors import Prior from torch import nn, Tensor _positivity_constraint = Positive() +SECOND_ORDER_PRIOR_ERROR_MSG = ( + "Second order interactions are disabled, but there is a prior on the second order " + "coefficients. Please remove the second order prior or enable second order terms." +) class OrthogonalAdditiveKernel(Kernel): @@ -40,6 +46,9 @@ def __init__( dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, coeff_constraint: Interval = _positivity_constraint, + offset_prior: Optional[Prior] = None, + coeffs_1_prior: Optional[Prior] = None, + coeffs_2_prior: Optional[Prior] = None, ): """ Args: @@ -52,9 +61,18 @@ def __init__( dtype: Initialization dtype for required Tensors. device: Initialization device for required Tensors. coeff_constraint: Constraint on the coefficients of the additive kernel. + offset_prior: Prior on the offset coefficient. Should be prior with non- + negative support. + coeffs_1_prior: Prior on the parameter main effects. Should be prior with + non-negative support. + coeffs_2_prior: coeffs_1_prior: Prior on the parameter interactions. Should + be prior with non-negative support. """ super().__init__(batch_shape=batch_shape) self.base_kernel = base_kernel + if not second_order and coeffs_2_prior is not None: + raise AttributeError(SECOND_ORDER_PRIOR_ERROR_MSG) + # integration nodes, weights for [0, 1] tkwargs = {"dtype": dtype, "device": device} z, w = leggauss(deg=quad_deg, a=0, b=1, **tkwargs) @@ -82,6 +100,29 @@ def __init__( else None ), ) + if offset_prior is not None: + self.register_prior( + name="offset_prior", + prior=offset_prior, + param_or_closure=self._offset_param, + setting_closure=self._offset_closure, + ) + if coeffs_1_prior is not None: + self.register_prior( + name="coeffs_1_prior", + prior=coeffs_1_prior, + param_or_closure=self._coeffs_1_param, + setting_closure=self._coeffs_1_closure, + ) + if coeffs_2_prior is not None: + self.register_prior( + name="coeffs_2_prior", + prior=coeffs_2_prior, + param_or_closure=self._coeffs_2_param, + setting_closure=self._coeffs_2_closure, + ) + + # for second order interactions, we only if second_order: self._rev_triu_indices = torch.tensor( _reverse_triu_indices(dim), @@ -95,7 +136,7 @@ def __init__( self.coeff_constraint = coeff_constraint self.dim = dim - def k(self, x1, x2) -> Tensor: + def k(self, x1: Tensor, x2: Tensor) -> Tensor: """Evaluates the kernel matrix base_kernel(x1, x2) on each input dimension independently. @@ -140,6 +181,52 @@ def coeffs_2(self) -> Optional[Tensor]: else: return None + def _coeffs_1_param(self, m: Module) -> Tensor: + return m.coeffs_1 + + def _coeffs_2_param(self, m: Module) -> Tensor: + return m.coeffs_2 + + def _offset_param(self, m: Module) -> Tensor: + return m.offset + + def _coeffs_1_closure(self, m: Module, v: Tensor) -> Tensor: + return m._set_coeffs_1(v) + + def _coeffs_2_closure(self, m: Module, v: Tensor) -> Tensor: + return m._set_coeffs_2(v) + + def _offset_closure(self, m: Module, v: Tensor) -> Tensor: + return m._set_offset(v) + + def _set_coeffs_1(self, value: Tensor): + value = torch.as_tensor(value).to(self.raw_coeffs_1) + value = value.expand(*self.batch_shape, self.dim) + self.initialize(raw_coeffs_1=self.coeff_constraint.inverse_transform(value)) + + def _set_coeffs_2(self, value: Tensor): + value = torch.as_tensor(value).to(self.raw_coeffs_1) + value = value.expand(*self.batch_shape, self.dim, self.dim) + row_idcs, col_idcs = torch.triu_indices(self.dim, self.dim, offset=1) + value = value[..., row_idcs, col_idcs].to(self.raw_coeffs_2) + self.initialize(raw_coeffs_2=self.coeff_constraint.inverse_transform(value)) + + def _set_offset(self, value: Tensor): + value = torch.as_tensor(value).to(self.raw_offset) + self.initialize(raw_offset=self.coeff_constraint.inverse_transform(value)) + + @coeffs_1.setter + def coeffs_1(self, value): + self._set_coeffs_1(value) + + @coeffs_2.setter + def coeffs_2(self, value): + self._set_coeffs_2(value) + + @offset.setter + def offset(self, value): + self._set_offset(value) + def forward( self, x1: Tensor, diff --git a/test/models/kernels/test_orthogonal_additive_kernel.py b/test/models/kernels/test_orthogonal_additive_kernel.py index 7f378b0034..7b39771c21 100644 --- a/test/models/kernels/test_orthogonal_additive_kernel.py +++ b/test/models/kernels/test_orthogonal_additive_kernel.py @@ -4,12 +4,23 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import itertools + import torch from botorch.exceptions.errors import UnsupportedError -from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel +from botorch.fit import fit_gpytorch_mll +from botorch.models import SingleTaskGP +from botorch.models.kernels.orthogonal_additive_kernel import ( + OrthogonalAdditiveKernel, + SECOND_ORDER_PRIOR_ERROR_MSG, +) from botorch.utils.testing import BotorchTestCase +from gpytorch.constraints import Positive from gpytorch.kernels import MaternKernel, RBFKernel from gpytorch.lazy import LazyEvaluatedKernelTensor +from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood +from gpytorch.priors import LogNormalPrior +from gpytorch.priors.torch_priors import GammaPrior, HalfCauchyPrior, UniformPrior from torch import nn, Tensor @@ -118,6 +129,184 @@ def test_kernel(self): tol = 1e-5 self.assertTrue(((K_ortho @ oak.w).squeeze(-1) < tol).all()) + def test_priors(self): + d = 5 + dtypes = [torch.float, torch.double] + batch_shapes = [(), (2,), (7, 2)] + + # test no prior + oak = OrthogonalAdditiveKernel( + RBFKernel(), dim=d, batch_shape=None, second_order=True + ) + for dtype, batch_shape in itertools.product(dtypes, batch_shapes): + # test with default args and batch_shape = None in second_order + tkwargs = {"dtype": dtype, "device": self.device} + offset_prior = HalfCauchyPrior(0.1).to(**tkwargs) + coeffs_1_prior = LogNormalPrior(0, 1).to(**tkwargs) + coeffs_2_prior = GammaPrior(3, 6).to(**tkwargs) + oak = OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + second_order=True, + offset_prior=offset_prior, + coeffs_1_prior=coeffs_1_prior, + coeffs_2_prior=coeffs_2_prior, + batch_shape=batch_shape, + **tkwargs, + ) + + self.assertIsInstance(oak.offset_prior, HalfCauchyPrior) + self.assertIsInstance(oak.coeffs_1_prior, LogNormalPrior) + self.assertEqual(oak.coeffs_1_prior.scale, 1) + self.assertEqual(oak.coeffs_2_prior.concentration, 3) + + oak = OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + second_order=True, + coeffs_1_prior=None, + coeffs_2_prior=coeffs_2_prior, + batch_shape=batch_shape, + **tkwargs, + ) + self.assertEqual(oak.coeffs_2_prior.concentration, 3) + with self.assertRaisesRegex( + AttributeError, + "'OrthogonalAdditiveKernel' object has no attribute " "'coeffs_1_prior", + ): + _ = oak.coeffs_1_prior + # test with batch_shape = None in second_order + oak = OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + second_order=True, + coeffs_1_prior=coeffs_1_prior, + batch_shape=batch_shape, + **tkwargs, + ) + with self.assertRaisesRegex(AttributeError, SECOND_ORDER_PRIOR_ERROR_MSG): + OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + batch_shape=None, + second_order=False, + coeffs_2_prior=GammaPrior(1, 1), + ) + + # train the model to ensure that param setters are called + train_X = torch.rand(5, d, dtype=dtype, device=self.device) + train_Y = torch.randn(5, 1, dtype=dtype, device=self.device) + + oak = OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + batch_shape=None, + second_order=True, + offset_prior=offset_prior, + coeffs_1_prior=coeffs_1_prior, + coeffs_2_prior=coeffs_2_prior, + **tkwargs, + ) + model = SingleTaskGP(train_X=train_X, train_Y=train_Y, covar_module=oak) + mll = ExactMarginalLogLikelihood(model.likelihood, model) + fit_gpytorch_mll(mll, optimizer_kwargs={"options": {"maxiter": 2}}) + + unif_prior = UniformPrior(10, 11) + # coeff_constraint is not enforced so that we can check the raw parameter + # values and not the reshaped (triu transformed) ones + oak_for_sample = OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + batch_shape=None, + second_order=True, + offset_prior=unif_prior, + coeffs_1_prior=unif_prior, + coeffs_2_prior=unif_prior, + coeff_constraint=Positive(transform=None, inv_transform=None), + **tkwargs, + ) + oak_for_sample.sample_from_prior("offset_prior") + oak_for_sample.sample_from_prior("coeffs_1_prior") + oak_for_sample.sample_from_prior("coeffs_2_prior") + + # check that all sampled values are within the bounds set by the priors + self.assertTrue(torch.all(10 <= oak_for_sample.raw_offset <= 11)) + self.assertTrue( + torch.all( + (10 <= oak_for_sample.raw_coeffs_1) + * (oak_for_sample.raw_coeffs_1 <= 11) + ) + ) + self.assertTrue( + torch.all( + (10 <= oak_for_sample.raw_coeffs_2) + * (oak_for_sample.raw_coeffs_2 <= 11) + ) + ) + + def test_set_coeffs(self): + d = 5 + dtype = torch.double + oak = OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + batch_shape=None, + second_order=True, + dtype=dtype, + ) + constraint = oak.coeff_constraint + coeffs_1 = torch.arange(d).to(dtype) + coeffs_2 = torch.ones((d * d)).reshape(d, d).triu().to(dtype) + oak.coeffs_1 = coeffs_1 + oak.coeffs_2 = coeffs_2 + + self.assertAllClose( + oak.raw_coeffs_1, + constraint.inverse_transform(coeffs_1), + ) + # raw_coeffs_2 has length d * (d-1) / 2 + self.assertAllClose( + oak.raw_coeffs_2, constraint.inverse_transform(torch.ones(10).to(dtype)) + ) + + batch_shapes = torch.Size([2]), torch.Size([5, 2]) + for batch_shape in batch_shapes: + dtype = torch.double + oak = OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + batch_shape=batch_shape, + second_order=True, + dtype=dtype, + coeff_constraint=Positive(transform=None, inv_transform=None), + ) + constraint = oak.coeff_constraint + coeffs_1 = torch.arange(d).to(dtype) + coeffs_2 = torch.ones((d * d)).reshape(d, d).triu().to(dtype) + oak.coeffs_1 = coeffs_1 + oak.coeffs_2 = coeffs_2 + + self.assertEqual(oak.raw_coeffs_1.shape, batch_shape + torch.Size([5])) + # raw_coeffs_2 has length d * (d-1) / 2 + self.assertEqual(oak.raw_coeffs_2.shape, batch_shape + torch.Size([10])) + + # test setting value as float + oak.offset = 0.5 + self.assertAllClose(oak.offset, 0.5 * torch.ones_like(oak.offset)) + # raw_coeffs_2 has length d * (d-1) / 2 + oak.coeffs_1 = 0.2 + self.assertAllClose( + oak.raw_coeffs_1, 0.2 * torch.ones_like(oak.raw_coeffs_1) + ) + oak.coeffs_2 = 0.3 + self.assertAllClose( + oak.raw_coeffs_2, 0.3 * torch.ones_like(oak.raw_coeffs_2) + ) + # the lower triangular part is set to 0 automatically since the + self.assertAllClose( + oak.coeffs_2.tril(diagonal=-1), torch.zeros_like(oak.coeffs_2) + ) + def isposdef(A: Tensor) -> bool: """Determines whether A is positive definite or not, by attempting a Cholesky