Fix kernel translation #311

Merged
merged 21 commits into from Jul 17, 2024
Changes from 17 commits
Commits
57295db
Fix bug by passing down argument to gpytorch kernel
AdrianSosic Jul 10, 2024
167f465
Improve attribute matching logic
AdrianSosic Jul 11, 2024
4d24fce
Add basic kernel validation tests
AdrianSosic Jul 11, 2024
e03b7b5
Add missing validators to basic kernels
AdrianSosic Jul 11, 2024
6559962
Add composite kernel validation tests
AdrianSosic Jul 11, 2024
ee8aa1b
Add missing validators to composite kernels
AdrianSosic Jul 11, 2024
436fb90
Fix kernel hypothesis strategies
AdrianSosic Jul 11, 2024
17080e6
Add kernel test module
AdrianSosic Jul 11, 2024
228f787
Set torch default dtype via workaround
AdrianSosic Jul 11, 2024
378b583
Fix optional keyword handling
AdrianSosic Jul 11, 2024
5348061
Test gpytorch kernel assembly
AdrianSosic Jul 12, 2024
6e0d799
Add validation for kernel keyword arguments
AdrianSosic Jul 12, 2024
3a7045f
Update CHANGELOG.md
AdrianSosic Jul 12, 2024
365898e
Move hypothesis strategy for positive finite floats to basic.py
AdrianSosic Jul 15, 2024
c178d91
Change name of test helper function
AdrianSosic Jul 15, 2024
45cc0bc
Fix mypy issues
AdrianSosic Jul 15, 2024
143da7c
Handle instance attributes that are not relevant for matching
AdrianSosic Jul 15, 2024
241b9b6
Remove float conversion from example
AdrianSosic Jul 15, 2024
894c35c
Add class property indicating non-botorch attributes
AdrianSosic Jul 15, 2024
51e3492
Add intermediate kernel class layer
AdrianSosic Jul 16, 2024
b794ae6
Move dtype workaround to to_gpytorch method
AdrianSosic Jul 16, 2024
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
- `Parameter.is_numeric` has been replaced with `Parameter.is_numerical`
- `DiscreteParameter.transform_rep_exp2comp` has been replaced with
`DiscreteParameter.transform`
- `filter_attributes` has been replaced with `match_attributes`

### Added
- `Surrogate` base class now exposes a `to_botorch` method
@@ -38,6 +39,7 @@ - `_optional` subpackage for managing optional dependencies
- `DiscreteParameter.to_subspace`, `ContinuousParameter.to_subspace` and
`Parameter.to_searchspace` convenience constructors
- Utilities for permutation and dependency data augmentation
- Validation and translation tests for kernels

### Changed
- Passing an `Objective` to `Campaign` is now optional
@@ -58,6 +60,10 @@ - `_optional` subpackage for managing optional dependencies
- Serialization bug related to class layout of `SKLearnClusteringRecommender`
- `MetaRecommender`s no longer trigger warnings about non-empty objectives or
measurements when calling a `NonPredictiveRecommender`
- Bug introduced in 0.9.0 (PR #221, commit 3078f3), where arguments to `to_gpytorch`
  were not passed on to the GPyTorch kernels
- Positive-valued kernel attributes are now correctly handled by validators
  and hypothesis strategies

### Deprecations
- `SequentialGreedyRecommender` class replaced with `BotorchRecommender`
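To make the fixed behavior concrete: a hedged sketch of the translation call the changelog entry refers to. The import path and the `base_kernel`/`ard_num_dims` attribute access are assumptions based on the BayBE and GPyTorch APIs touched in this PR, not verified against a specific release.

```python
# Hedged sketch: keyword arguments passed to `to_gpytorch` should now be
# forwarded to nested gpytorch kernels instead of being silently dropped.
from baybe.kernels import MaternKernel, ScaleKernel

kernel = ScaleKernel(MaternKernel(nu=2.5))
gpytorch_kernel = kernel.to_gpytorch(ard_num_dims=3)

# Before the fix, the inner Matern kernel was built without `ard_num_dims`;
# now it carries one lengthscale per input dimension.
assert gpytorch_kernel.base_kernel.ard_num_dims == 3
```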
22 changes: 19 additions & 3 deletions baybe/acquisition/base.py
@@ -8,7 +8,7 @@
from typing import ClassVar

import pandas as pd
from attrs import define
from attrs import define, fields

from baybe.searchspace import SearchSpace
from baybe.serialization.core import (
@@ -18,7 +18,7 @@
)
from baybe.serialization.mixin import SerialMixin
from baybe.surrogates.base import Surrogate
from baybe.utils.basic import classproperty, filter_attributes
from baybe.utils.basic import classproperty, match_attributes
from baybe.utils.boolean import is_abstract
from baybe.utils.dataframe import to_tensor

@@ -45,9 +45,25 @@ def to_botorch(
"""Create the botorch-ready representation of the function."""
import botorch.acquisition as botorch_acqf_module

from baybe.acquisition.acqfs import qNegIntegratedPosteriorVariance

# Retrieve corresponding botorch class
acqf_cls = getattr(botorch_acqf_module, self.__class__.__name__)
params_dict = filter_attributes(object=self, callable_=acqf_cls.__init__)

# Match relevant attributes
flds = fields(qNegIntegratedPosteriorVariance)
ignore = (
(
flds.sampling_n_points.name,
flds.sampling_method.name,
flds.sampling_fraction.name,
)
if isinstance(self, qNegIntegratedPosteriorVariance)
else ()
)
params_dict = match_attributes(self, acqf_cls.__init__, ignore=ignore)[0]

# Collect remaining (context-specific) parameters
signature_params = signature(acqf_cls).parameters
additional_params = {}
if "model" in signature_params:
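Since `match_attributes` replaces `filter_attributes` throughout this PR, a minimal sketch of the matched/unmatched split that the call site above relies on may help. The helper below is an illustrative stand-in, not the real implementation from `baybe.utils.basic`, which may differ in signature and edge-case handling; the `[0]` in `match_attributes(...)[0]` selects the matched part of the returned pair.

```python
# Illustrative stand-in for baybe.utils.basic.match_attributes: split the
# attrs fields of `obj` by whether `callable_` accepts them as parameters.
from inspect import signature

from attrs import define, fields


def match_attributes_sketch(obj, callable_, ignore=()):
    params = signature(callable_).parameters
    matched, unmatched = {}, {}
    for fld in fields(type(obj)):
        if fld.name in ignore:
            continue  # e.g. the qNIPV sampling fields handled separately above
        target = matched if fld.name in params else unmatched
        target[fld.name] = getattr(obj, fld.name)
    return matched, unmatched


@define
class Example:
    alpha: float = 1.0
    beta: float = 2.0


def consumer(alpha):  # accepts only `alpha`
    return alpha


matched, unmatched = match_attributes_sketch(Example(), consumer)
assert matched == {"alpha": 1.0} and unmatched == {"beta": 2.0}
```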
4 changes: 4 additions & 0 deletions baybe/exceptions.py
@@ -59,3 +59,7 @@ class DeprecationError(Exception):

class UnidentifiedSubclassError(Exception):
"""A specified subclass cannot be found in the given class hierarchy."""


class UnmatchedAttributeError(Exception):
"""An attribute cannot be matched against a certain callable signature."""
56 changes: 38 additions & 18 deletions baybe/kernels/base.py
@@ -3,18 +3,19 @@
from __future__ import annotations

from abc import ABC
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from attrs import define

from baybe.exceptions import UnmatchedAttributeError
from baybe.priors.base import Prior
from baybe.serialization.core import (
converter,
get_base_structure_hook,
unstructure_base,
)
from baybe.serialization.mixin import SerialMixin
from baybe.utils.basic import filter_attributes, get_baseclasses
from baybe.utils.basic import get_baseclasses, match_attributes

if TYPE_CHECKING:
import torch
@@ -42,38 +43,57 @@ def to_gpytorch(
"""Create the gpytorch representation of the kernel."""
import gpytorch.kernels

# Fetch the necessary gpytorch constructor parameters of the kernel.
# NOTE: In gpytorch, some attributes (like the kernel lengthscale) are handled
# via the `gpytorch.kernels.Kernel` base class. Hence, it is not sufficient to
# just check the fields of the actual class, but also those of the base class.
# Extract keywords with non-default values. This is required since gpytorch
# makes use of kwargs, i.e. it differentiates whether certain keywords are
# explicitly passed or not. For instance, `ard_num_dims = kwargs.get("ard_num_dims", 1)`
# yields `None` instead of the default if we explicitly pass `ard_num_dims=None`.
kw: dict[str, Any] = dict(
ard_num_dims=ard_num_dims, batch_shape=batch_shape, active_dims=active_dims
)
kw = {k: v for k, v in kw.items() if v is not None}

# Get corresponding gpytorch kernel class and its base classes
kernel_cls = getattr(gpytorch.kernels, self.__class__.__name__)
base_classes = get_baseclasses(kernel_cls, abstract=True)
fields_dict = {}

# Fetch the necessary gpytorch constructor parameters of the kernel.
# NOTE: In gpytorch, some attributes (like the kernel lengthscale) are handled
# via the `gpytorch.kernels.Kernel` base class. Hence, it is not sufficient to
# just check the fields of the actual class, but also those of its base
# classes.
kernel_attrs: dict[str, Any] = {}
unmatched_attrs: dict[str, Any] = {}
for cls in [kernel_cls, *base_classes]:
fields_dict.update(filter_attributes(object=self, callable_=cls.__init__))
matched, unmatched = match_attributes(self, cls.__init__, strict=False)
kernel_attrs.update(matched)
unmatched_attrs.update(unmatched)

# Sanity check: all attributes of the BayBE kernel need a corresponding match
# in the gpytorch kernel (otherwise, the BayBE kernel class is misconfigured).
# Exception: initial values are not used during construction but are set
# on the created object (see code at the end of the method).
missing = set(unmatched_attrs) - set(kernel_attrs)
if leftover := {m for m in missing if not m.endswith("_initial_value")}:
raise UnmatchedAttributeError(leftover)

# Convert specified priors to gpytorch, if provided
prior_dict = {
key: value.to_gpytorch()
for key, value in fields_dict.items()
for key, value in kernel_attrs.items()
if isinstance(value, Prior)
}

# Convert specified inner kernels to gpytorch, if provided
kernel_dict = {
key: value.to_gpytorch(
ard_num_dims=ard_num_dims,
batch_shape=batch_shape,
active_dims=active_dims,
)
for key, value in fields_dict.items()
key: value.to_gpytorch(**kw)
for key, value in kernel_attrs.items()
if isinstance(value, Kernel)
}

# Create the kernel with all its inner gpytorch objects
fields_dict.update(kernel_dict)
fields_dict.update(prior_dict)
gpytorch_kernel = kernel_cls(**fields_dict)
kernel_attrs.update(kernel_dict)
kernel_attrs.update(prior_dict)
gpytorch_kernel = kernel_cls(**kernel_attrs, **kw)

# If the kernel has a lengthscale, set its initial value
if kernel_cls.has_lengthscale:
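The keyword-filtering logic at the top of `to_gpytorch` is subtle enough to warrant a standalone illustration. Below is a minimal, plain-Python sketch (no gpytorch involved) of why `None`-valued keywords must be dropped before forwarding:

```python
# A callee that uses kwargs.get(...) with a default behaves differently when
# a key is passed explicitly as None -- exactly the gpytorch failure mode
# described in the comment above.
def callee(**kwargs):
    return kwargs.get("ard_num_dims", 1)


assert callee() == 1  # default applies
assert callee(ard_num_dims=None) is None  # explicit None overrides the default

kw = {"ard_num_dims": None, "batch_shape": None}
kw = {k: v for k, v in kw.items() if v is not None}  # the PR's workaround
assert callee(**kw) == 1  # default is restored
```

Dropping the `None` entries restores the callee's defaults, which is what the final `kernel_cls(**kernel_attrs, **kw)` call relies on.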
38 changes: 28 additions & 10 deletions baybe/kernels/basic.py
@@ -3,7 +3,7 @@

from attrs import define, field
from attrs.converters import optional as optional_c
from attrs.validators import ge, in_, instance_of
from attrs.validators import ge, gt, in_, instance_of
from attrs.validators import optional as optional_v

from baybe.kernels.base import Kernel
@@ -22,7 +22,9 @@ class LinearKernel(Kernel):
"""An optional prior on the kernel variance parameter."""

variance_initial_value: float | None = field(
default=None, converter=optional_c(float), validator=optional_v(finite_float)
default=None,
converter=optional_c(float),
validator=optional_v([finite_float, gt(0.0)]),
)
"""An optional initial value for the kernel variance parameter."""

@@ -58,7 +60,9 @@ class MaternKernel(Kernel):
"""An optional prior on the kernel lengthscale."""

lengthscale_initial_value: float | None = field(
default=None, converter=optional_c(float), validator=optional_v(finite_float)
default=None,
converter=optional_c(float),
validator=optional_v([finite_float, gt(0.0)]),
)
"""An optional initial value for the kernel lengthscale."""

@@ -73,7 +77,9 @@ class PeriodicKernel(Kernel):
"""An optional prior on the kernel lengthscale."""

lengthscale_initial_value: float | None = field(
default=None, converter=optional_c(float), validator=optional_v(finite_float)
default=None,
converter=optional_c(float),
validator=optional_v([finite_float, gt(0.0)]),
)
"""An optional initial value for the kernel lengthscale."""

@@ -83,7 +89,9 @@ class PeriodicKernel(Kernel):
"""An optional prior on the kernel period length."""

period_length_initial_value: float | None = field(
default=None, converter=optional_c(float), validator=optional_v(finite_float)
default=None,
converter=optional_c(float),
validator=optional_v([finite_float, gt(0.0)]),
)
"""An optional initial value for the kernel period length."""

@@ -116,7 +124,9 @@ class PiecewisePolynomialKernel(Kernel):
"""An optional prior on the kernel lengthscale."""

lengthscale_initial_value: float | None = field(
default=None, converter=optional_c(float), validator=optional_v(finite_float)
default=None,
converter=optional_c(float),
validator=optional_v([finite_float, gt(0.0)]),
)
"""An optional initial value for the kernel lengthscale."""

@@ -134,7 +144,9 @@ class PolynomialKernel(Kernel):
"""An optional prior on the kernel offset."""

offset_initial_value: float | None = field(
default=None, converter=optional_c(float), validator=optional_v(finite_float)
default=None,
converter=optional_c(float),
validator=optional_v([finite_float, gt(0.0)]),
)
"""An optional initial value for the kernel offset."""

@@ -160,7 +172,9 @@ class RBFKernel(Kernel):
"""An optional prior on the kernel lengthscale."""

lengthscale_initial_value: float | None = field(
default=None, converter=optional_c(float), validator=optional_v(finite_float)
default=None,
converter=optional_c(float),
validator=optional_v([finite_float, gt(0.0)]),
)
"""An optional initial value for the kernel lengthscale."""

@@ -178,7 +192,9 @@ class RFFKernel(Kernel):
"""An optional prior on the kernel lengthscale."""

lengthscale_initial_value: float | None = field(
default=None, converter=optional_c(float), validator=optional_v(finite_float)
default=None,
converter=optional_c(float),
validator=optional_v([finite_float, gt(0.0)]),
)
"""An optional initial value for the kernel lengthscale."""

@@ -193,6 +209,8 @@ class RQKernel(Kernel):
"""An optional prior on the kernel lengthscale."""

lengthscale_initial_value: float | None = field(
default=None, converter=optional_c(float), validator=optional_v(finite_float)
default=None,
converter=optional_c(float),
validator=optional_v([finite_float, gt(0.0)]),
)
"""An optional initial value for the kernel lengthscale."""
16 changes: 12 additions & 4 deletions baybe/kernels/composite.py
@@ -4,7 +4,7 @@

from attrs import define, field
from attrs.converters import optional as optional_c
from attrs.validators import deep_iterable, instance_of
from attrs.validators import deep_iterable, gt, instance_of, min_len
from attrs.validators import optional as optional_v

from baybe.kernels.base import Kernel
@@ -25,7 +25,9 @@ class ScaleKernel(Kernel):
"""An optional prior on the output scale."""

outputscale_initial_value: float | None = field(
default=None, converter=optional_c(float), validator=optional_v(finite_float)
default=None,
converter=optional_c(float),
validator=optional_v([finite_float, gt(0.0)]),
)
"""An optional initial value for the output scale."""

@@ -48,7 +50,10 @@ class AdditiveKernel(Kernel):
"""A kernel representing the sum of a collection of base kernels."""

base_kernels: tuple[Kernel, ...] = field(
converter=tuple, validator=deep_iterable(member_validator=instance_of(Kernel))
converter=tuple,
validator=deep_iterable(
member_validator=instance_of(Kernel), iterable_validator=min_len(2)
),
)
"""The individual kernels to be summed."""

@@ -63,7 +68,10 @@ class ProductKernel(Kernel):
"""A kernel representing the product of a collection of base kernels."""

base_kernels: tuple[Kernel, ...] = field(
converter=tuple, validator=deep_iterable(member_validator=instance_of(Kernel))
converter=tuple,
validator=deep_iterable(
member_validator=instance_of(Kernel), iterable_validator=min_len(2)
),
)
"""The individual kernels to be multiplied."""

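The new `min_len(2)` constraint deserves a quick demonstration. A hedged sketch with stand-in classes (the real ones live in `baybe.kernels`):

```python
# Sketch of the composite-kernel validation: at least two members, all of
# which must be Kernel instances.
from attrs import define, field
from attrs.validators import deep_iterable, instance_of, min_len


@define
class Kernel:  # stand-in for baybe.kernels.base.Kernel
    pass


@define
class AdditiveKernelDemo:
    base_kernels: tuple = field(
        converter=tuple,
        validator=deep_iterable(
            member_validator=instance_of(Kernel), iterable_validator=min_len(2)
        ),
    )


AdditiveKernelDemo((Kernel(), Kernel()))  # two kernels: accepted
try:
    AdditiveKernelDemo((Kernel(),))  # a one-element "sum" is now rejected
except ValueError:
    pass
```

Requiring at least two base kernels rules out degenerate one-element sums and products, matching the semantics of `AdditiveKernel` and `ProductKernel`.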
6 changes: 4 additions & 2 deletions baybe/priors/base.py
@@ -10,7 +10,7 @@
unstructure_base,
)
from baybe.serialization.mixin import SerialMixin
from baybe.utils.basic import filter_attributes
from baybe.utils.basic import match_attributes, set_default_torch_dtype


@define(frozen=True)
@@ -21,8 +21,10 @@ def to_gpytorch(self, *args, **kwargs):
"""Create the gpytorch representation of the prior."""
import gpytorch.priors

set_default_torch_dtype()

prior_cls = getattr(gpytorch.priors, self.__class__.__name__)
fields_dict = filter_attributes(object=self, callable_=prior_cls.__init__)
fields_dict = match_attributes(self, prior_cls.__init__)[0]

# Update kwargs to contain class-specific attributes
kwargs.update(fields_dict)
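Finally, `set_default_torch_dtype` is presumably the workaround from the commit "Set torch default dtype via workaround". Its body below is an assumption: a helper that pins torch's default floating-point dtype so that gpytorch priors (and kernels) are constructed in double precision.

```python
# Hypothetical stand-in for the helper imported from baybe.utils.basic.
import torch


def set_default_torch_dtype(dtype: torch.dtype = torch.float64) -> None:
    """Pin torch's default dtype (assumed behavior of the real helper)."""
    torch.set_default_dtype(dtype)


set_default_torch_dtype()
assert torch.tensor(1.0).dtype == torch.float64  # new tensors are double precision
```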