diff --git a/src/hyfi/core/hydra.py b/src/hyfi/core/hydra.py index 2e4ad0ea..f309fc7a 100644 --- a/src/hyfi/core/hydra.py +++ b/src/hyfi/core/hydra.py @@ -11,7 +11,11 @@ from hydra.core.singleton import Singleton from hydra.errors import HydraException -from hyfi.core import __config_module__ +from hyfi.core import __config_module_path__, __config_path__ +from hyfi.utils.logging import LOGGING +from hyfi.utils.packages import PKGs + +logger = LOGGING.getLogger(__name__) def get_gh_backup() -> Any: @@ -28,6 +32,24 @@ def restore_gh_from_backup(_gh_backup: Any) -> Any: Singleton._instances[GlobalHydra] = _gh_backup +def get_caller_config_module_path( + config_path: Optional[str] = __config_path__, +) -> Optional[str]: + """Returns the path to the caller module's config folder""" + caller_module_name = PKGs.get_caller_module_name() + config_module = caller_module_name.split(".")[0] + config_module_path = f"{config_module}.{config_path}" + if config_module_path == __config_module_path__: + return config_module_path + # check if the config module is importable + try: + __import__(config_module_path) + return config_module_path + except ImportError: + logger.info("Config module not found: %s", config_module_path) + return None + + _UNSPECIFIED_: Any = object() @@ -86,11 +108,16 @@ def create_config_search_path( search_path = ConfigSearchPathImpl() search_path.append("hydra", "pkg://hydra.conf") - if config_module is not None: + if config_module: search_path.append("main", f"pkg://{config_module}") - - if config_module != __config_module__: - search_path.append("hyfi", f"pkg://{__config_module__}") + caller_config_module = get_caller_config_module_path() + if caller_config_module: + search_path.append("caller", f"pkg://{caller_config_module}") + if ( + config_module != __config_module_path__ + and caller_config_module != __config_module_path__ + ): + search_path.append("hyfi", f"pkg://{__config_module_path__}") if search_path_dir is not None and os.path.isdir(search_path_dir): search_path.append("hyfi", f"file://{search_path_dir}")