Skip to content

Commit

Permalink
Handle instance attributes that are not relevant for matching
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Jul 15, 2024
1 parent 45cc0bc commit 143da7c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
20 changes: 18 additions & 2 deletions baybe/acquisition/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import ClassVar

import pandas as pd
from attrs import define
from attrs import define, fields

from baybe.searchspace import SearchSpace
from baybe.serialization.core import (
Expand Down Expand Up @@ -45,9 +45,25 @@ def to_botorch(
"""Create the botorch-ready representation of the function."""
import botorch.acquisition as botorch_acqf_module

from baybe.acquisition.acqfs import qNegIntegratedPosteriorVariance

# Retrieve corresponding botorch class
acqf_cls = getattr(botorch_acqf_module, self.__class__.__name__)
params_dict = match_attributes(self, acqf_cls.__init__)[0]

# Match relevant attributes
flds = fields(qNegIntegratedPosteriorVariance)
ignore = (
(
flds.sampling_n_points.name,
flds.sampling_method.name,
flds.sampling_fraction.name,
)
if isinstance(self, qNegIntegratedPosteriorVariance)
else ()
)
params_dict = match_attributes(self, acqf_cls.__init__, ignore=ignore)[0]

# Collect remaining (context-specific) parameters
signature_params = signature(acqf_cls).parameters
additional_params = {}
if "model" in signature_params:
Expand Down
10 changes: 6 additions & 4 deletions baybe/utils/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import functools
import inspect
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Collection, Iterable, Sequence
from dataclasses import dataclass
from typing import Any, TypeVar

Expand Down Expand Up @@ -128,6 +128,7 @@ def match_attributes(
callable_: Callable,
/,
strict: bool = True,
ignore: Collection[str] = (),
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Find the attributes of an object that match with a given callable signature.
Expand All @@ -139,8 +140,9 @@ def match_attributes(
callable_: The callable against whose signature the attributes are to be
matched.
strict: If ``True``, an error is raised when the object has attributes that
are not found in the callable signature. If ``False``, these attributes are
returned separately.
are not found in the callable signature (see also ``ignore``).
If ``False``, these attributes are returned separately.
ignore: A collection of attributes names that are to be ignored during matching.
Raises:
ValueError: If applied to a non-attrs object.
Expand All @@ -155,7 +157,7 @@ def match_attributes(
f"'{match_attributes.__name__}' only works with attrs objects."
)
# Get attribute/parameter sets
set_object = set(asdict(object))
set_object = set(asdict(object)) - set(ignore)
set_callable = set(inspect.signature(callable_).parameters)

# Match
Expand Down

0 comments on commit 143da7c

Please sign in to comment.