diff --git a/src/labthings/extensions.py b/src/labthings/extensions.py index 81384283..4861f998 100644 --- a/src/labthings/extensions.py +++ b/src/labthings/extensions.py @@ -3,6 +3,7 @@ import os import sys import traceback +import inspect from importlib import util from typing import Callable, Dict, List, Union @@ -269,15 +270,26 @@ def find_extensions_in_file(extension_path: str, module_name="extensions") -> li else: # TODO: Add documentation links to warnings if hasattr(mod, "LABTHINGS_EXTENSIONS"): - return [ - ext_class() - for ext_class in getattr(mod, "LABTHINGS_EXTENSIONS") - if issubclass(ext_class, BaseExtension) - ] + ext_objects = [] + for ext_element in getattr(mod, "LABTHINGS_EXTENSIONS"): + if inspect.isclass(ext_element) and issubclass( + ext_element, BaseExtension + ): + ext_objects.append(ext_element()) + elif isinstance(ext_element, BaseExtension): + logging.warning( + "%s: Extension instance passed instead of class. LABTHINGS_EXTENSIONS should contain classes, not instances.", + ext_element, + ) + ext_objects.append(ext_element) + else: + logging.error( + "Unsupported extension type %s. Skipping.", type(ext_element) + ) + return ext_objects elif hasattr(mod, "__extensions__"): logging.warning( - "Explicit extension list using the __extensions__ global is deprecated.", - "Please use LABTHINGS_EXTENSIONS instead.", + "Explicit extension list using the __extensions__ global is deprecated. Please use LABTHINGS_EXTENSIONS instead." ) return [ getattr(mod, ext_name) for ext_name in getattr(mod, "__extensions__")