Skip to content

Commit

Permalink
Rework get_base_clases
Browse files Browse the repository at this point in the history
  • Loading branch information
AVHopp committed May 13, 2024
1 parent 65e9884 commit f151de1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion baybe/kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def to_gpytorch(
# via the `gpytorch.kernels.Kernel` base class. Hence, it is not sufficient to
# just check the fields of the actual class, but also those of the base class.
kernel_cls = getattr(gpytorch.kernels, self.__class__.__name__)
base_classes = get_base_classes(kernel_cls, include_class=True)
base_classes = get_base_classes(kernel_cls, abstract=True, include_class=True)
fields_dict = {}
for parent_class in base_classes:
fields_dict.update(
Expand Down
14 changes: 7 additions & 7 deletions baybe/utils/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,20 @@ def get_base_classes(
recursive: bool = True,
abstract: bool = False,
include_class: bool = False,
) -> list[type]:
) -> set[type]:
"""Return a list of base classes for the given class.
Args:
cls: The class to retrieve base classes for.
recursive: If ``True``, indirect base classes (i.e., base classes of base
classes) are included.
abstract: If `True`, abstract base classes are included.
include_class: If ``True``, the class itself is included.
include_class: If ``True``, the class itself is included. Note that this will
include the class under any circumstances, that is, even if it is abstract
and ``abstract``was set to ``False``.
Returns:
A list of base classes for the given class.
A set of base classes for the given class.
"""
from baybe.utils.boolean import is_abstract

Expand All @@ -82,12 +84,10 @@ def get_base_classes(

if recursive:
classes.extend(
get_base_classes(cls, abstract=abstract, include_class=False)
if base_class not in classes
else []
get_base_classes(base_class, abstract=abstract, include_class=False)
)

return classes
return set(classes)


def set_random_seed(seed: int):
Expand Down

0 comments on commit f151de1

Please sign in to comment.