Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom kernel priors #219

Merged
merged 5 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Deserialization is now also possible from optional class name abbreviations
- `Kernel` base class allowing to specify kernels
- `MaternKernel` class can be chosen for GP surrogates
- `hypothesis` strategies and roundtrip test for kernels, constraints, objectives and acquisition
functions
- `hypothesis` strategies and roundtrip tests for kernels, constraints, objectives, priors,
  and acquisition functions
- New acquisition functions: `qSR`, `qNEI`, `LogEI`, `qLogEI`, `qLogNEI`
- `GammaPrior` can now be chosen as lengthscale prior
AVHopp marked this conversation as resolved.
Show resolved Hide resolved

### Changed
AVHopp marked this conversation as resolved.
Show resolved Hide resolved
- Reorganized acquisition.py into `acquisition` subpackage
Expand Down
12 changes: 11 additions & 1 deletion baybe/kernels/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Base classes for all kernels."""

from abc import ABC
from typing import Optional

from attrs import define
from attrs import define, field

from baybe.kernels.priors.base import Prior
from baybe.serialization.core import (
converter,
get_base_structure_hook,
Expand All @@ -17,12 +19,20 @@
class Kernel(ABC, SerialMixin):
"""Abstract base class for all kernels."""

lengthscale_prior: Optional[Prior] = field(default=None, kw_only=True)
"""An optional prior on the kernel lengthscale."""

def to_gpytorch(self, *args, **kwargs):
"""Create the gpytorch representation of the kernel."""
import gpytorch.kernels

kernel_cls = getattr(gpytorch.kernels, self.__class__.__name__)
fields_dict = filter_attributes(object=self, callable_=kernel_cls.__init__)

# If a lengthscale prior was chosen, we manually add it to the dictionary
if self.lengthscale_prior is not None:
fields_dict["lengthscale_prior"] = self.lengthscale_prior.to_gpytorch()

# Update kwargs to contain class-specific attributes
kwargs.update(fields_dict)

Expand Down
2 changes: 1 addition & 1 deletion baybe/kernels/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _convert_fraction(value: Union[str, float, Fraction], /) -> float:
return float(value)


@define
@define(frozen=True)
class MaternKernel(Kernel):
"""A Matern kernel using a smoothness parameter."""

Expand Down
5 changes: 5 additions & 0 deletions baybe/kernels/priors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Available priors."""

from baybe.kernels.priors.basic import GammaPrior

__all__ = ["GammaPrior"]
35 changes: 35 additions & 0 deletions baybe/kernels/priors/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Base class for all priors."""

from abc import ABC

from attrs import define

from baybe.serialization.core import (
converter,
get_base_structure_hook,
unstructure_base,
)
from baybe.serialization.mixin import SerialMixin
from baybe.utils.basic import filter_attributes


@define(frozen=True)
class Prior(ABC, SerialMixin):
    """Abstract base class for all priors."""

    def to_gpytorch(self, *args, **kwargs):
        """Create the gpytorch representation of the prior."""
        # Imported lazily so that gpytorch is only required when actually used.
        import gpytorch.priors

        # The gpytorch counterpart is looked up by the matching class name.
        gpytorch_cls = getattr(gpytorch.priors, type(self).__name__)

        # Merge the class-specific attributes into the keyword arguments
        kwargs.update(filter_attributes(object=self, callable_=gpytorch_cls.__init__))

        return gpytorch_cls(*args, **kwargs)


# Register de-/serialization hooks
converter.register_structure_hook(Prior, get_base_structure_hook(Prior))
converter.register_unstructure_hook(Prior, unstructure_base)
16 changes: 16 additions & 0 deletions baybe/kernels/priors/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Priors that can be used for kernels."""
from attrs import define, field
from attrs.validators import gt

from baybe.kernels.priors.base import Prior


@define(frozen=True)
class GammaPrior(Prior):
    """A Gamma prior parameterized by concentration and rate.

    The corresponding gpytorch object is created by name lookup in
    ``Prior.to_gpytorch``.
    """

    # Both parameters are coerced to float and validated to be strictly positive.
    concentration: float = field(converter=float, validator=gt(0.0))
    """The concentration."""

    rate: float = field(converter=float, validator=gt(0.0))
    """The rate."""
17 changes: 11 additions & 6 deletions baybe/surrogates/gaussian_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING, ClassVar, Optional

from attr import define, field

from baybe.kernels import MaternKernel
from baybe.kernels.base import Kernel
from baybe.kernels.priors import GammaPrior
from baybe.searchspace import SearchSpace
from baybe.surrogates.base import Surrogate

Expand All @@ -27,7 +28,7 @@ class GaussianProcessSurrogate(Surrogate):
# See base class.

# Object variables
kernel: Kernel = field(factory=MaternKernel)
kernel: Optional[Kernel] = field(default=None)
AVHopp marked this conversation as resolved.
Show resolved Hide resolved
"""The kernel used by the Gaussian Process."""

# TODO: type should be Optional[botorch.models.SingleTaskGP] but is currently
Expand All @@ -46,7 +47,6 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No
import botorch
import gpytorch
import torch
from gpytorch.priors import GammaPrior

# identify the indexes of the task and numeric dimensions
# TODO: generalize to multiple task parameters
Expand All @@ -72,6 +72,8 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No
train_x.shape[-1] >= 50
)

# TODO Until now, only the kernels use our custom priors, hence the explicit
# to_gpytorch() calls for all others
# low D priors
if train_x.shape[-1] < 10:
lengthscale_prior = [GammaPrior(1.2, 1.1), 0.2]
Expand Down Expand Up @@ -104,17 +106,20 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No
# create GP mean
mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)

# If no kernel is provided, we construct one from our priors
if self.kernel is None:
self.kernel = MaternKernel(lengthscale_prior=lengthscale_prior[0])
AdrianSosic marked this conversation as resolved.
Show resolved Hide resolved

# define the covariance module for the numeric dimensions
gpytorch_kernel = self.kernel.to_gpytorch(
ard_num_dims=train_x.shape[-1] - n_task_params,
active_dims=numeric_idxs,
batch_shape=batch_shape,
lengthscale_prior=lengthscale_prior[0],
)
base_covar_module = gpytorch.kernels.ScaleKernel(
gpytorch_kernel,
batch_shape=batch_shape,
outputscale_prior=outputscale_prior[0],
outputscale_prior=outputscale_prior[0].to_gpytorch(),
)
if outputscale_prior[1] is not None:
base_covar_module.outputscale = torch.tensor([outputscale_prior[1]])
Expand All @@ -136,7 +141,7 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No

# create GP likelihood
likelihood = gpytorch.likelihoods.GaussianLikelihood(
noise_prior=noise_prior[0], batch_shape=batch_shape
noise_prior=noise_prior[0].to_gpytorch(), batch_shape=batch_shape
)
if noise_prior[1] is not None:
likelihood.noise = torch.tensor([noise_prior[1]])
Expand Down
9 changes: 8 additions & 1 deletion tests/hypothesis_strategies/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,11 @@

from baybe.kernels import MaternKernel

matern_kernels = st.builds(MaternKernel, st.sampled_from((0.5, 1.5, 2.5)))
from ..hypothesis_strategies.priors import priors

matern_kernels = st.builds(
    MaternKernel,
    # Only the smoothness values supported by gpytorch's Matern kernel
    nu=st.sampled_from([0.5, 1.5, 2.5]),
    # The `|` operator is hypothesis' shorthand for `st.one_of`
    lengthscale_prior=st.none() | priors,
)
"""A strategy that generates Matern kernels."""
19 changes: 19 additions & 0 deletions tests/hypothesis_strategies/priors.py
AVHopp marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Hypothesis strategies for priors."""
AVHopp marked this conversation as resolved.
Show resolved Hide resolved

import hypothesis.strategies as st

from baybe.kernels.priors import GammaPrior

# Strictly positive, finite floats: Gamma parameters must be > 0, and ``inf``
# is neither a meaningful distribution parameter nor portable through strict
# JSON de-/serialization. (NaN is already excluded once ``min_value`` is set.)
_positive_finite_floats = st.floats(
    min_value=0, exclude_min=True, allow_infinity=False
)

gamma_priors = st.builds(
    GammaPrior,
    _positive_finite_floats,
    _positive_finite_floats,
)
"""A strategy that generates Gamma priors."""

priors = st.one_of(
    [
        gamma_priors,
    ]
)
"""A strategy that generates priors."""
13 changes: 13 additions & 0 deletions tests/serialization/test_prior_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Test serialization of priors."""

from hypothesis import given

from baybe.kernels.priors.base import Prior
from tests.hypothesis_strategies.priors import priors


@given(priors)
def test_prior_kernel_roundtrip(prior: Prior):
    """A serialization roundtrip yields an equivalent prior."""
    reconstructed = Prior.from_json(prior.to_json())
    assert prior == reconstructed, (prior, reconstructed)