From f151de1907e738467bfaccd51f1f331af73ce525 Mon Sep 17 00:00:00 2001 From: "Alexander V. Hopp" Date: Mon, 13 May 2024 13:10:07 +0200 Subject: [PATCH] Rework get_base_clases --- baybe/kernels/base.py | 2 +- baybe/utils/basic.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/baybe/kernels/base.py b/baybe/kernels/base.py index ee14b944a5..132210b237 100644 --- a/baybe/kernels/base.py +++ b/baybe/kernels/base.py @@ -39,7 +39,7 @@ def to_gpytorch( # via the `gpytorch.kernels.Kernel` base class. Hence, it is not sufficient to # just check the fields of the actual class, but also those of the base class. kernel_cls = getattr(gpytorch.kernels, self.__class__.__name__) - base_classes = get_base_classes(kernel_cls, include_class=True) + base_classes = get_base_classes(kernel_cls, abstract=True, include_class=True) fields_dict = {} for parent_class in base_classes: fields_dict.update( diff --git a/baybe/utils/basic.py b/baybe/utils/basic.py index c13ccd3c22..df2b429b5b 100644 --- a/baybe/utils/basic.py +++ b/baybe/utils/basic.py @@ -59,7 +59,7 @@ def get_base_classes( recursive: bool = True, abstract: bool = False, include_class: bool = False, -) -> list[type]: +) -> set[type]: """Return a list of base classes for the given class. Args: @@ -67,10 +67,12 @@ def get_base_classes( recursive: If ``True``, indirect base classes (i.e., base classes of base classes) are included. abstract: If `True`, abstract base classes are included. - include_class: If ``True``, the class itself is included. + include_class: If ``True``, the class itself is included. Note that this will + include the class under any circumstances, that is, even if it is abstract + and ``abstract``was set to ``False``. Returns: - A list of base classes for the given class. + A set of base classes for the given class. """ from baybe.utils.boolean import is_abstract @@ -82,12 +84,10 @@ def get_base_classes( if recursive: classes.extend( - get_base_classes(cls, abstract=abstract, include_class=False) - if base_class not in classes - else [] + get_base_classes(base_class, abstract=abstract, include_class=False) ) - return classes + return set(classes) def set_random_seed(seed: int):