Skip to content

Commit

Permalink
feat(hyfi/core/hydra): add function get_caller_config_module_path
Browse files Browse the repository at this point in the history
  • Loading branch information
entelecheia committed Jul 22, 2023
1 parent ef66024 commit 074c538
Showing 1 changed file with 32 additions and 5 deletions.
37 changes: 32 additions & 5 deletions src/hyfi/core/hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from hydra.core.singleton import Singleton
from hydra.errors import HydraException

from hyfi.core import __config_module__
from hyfi.core import __config_module_path__, __config_path__
from hyfi.utils.logging import LOGGING
from hyfi.utils.packages import PKGs

logger = LOGGING.getLogger(__name__)


def get_gh_backup() -> Any:
Expand All @@ -28,6 +32,24 @@ def restore_gh_from_backup(_gh_backup: Any) -> Any:
Singleton._instances[GlobalHydra] = _gh_backup


def get_caller_config_module_path(
config_path: Optional[str] = __config_path__,
) -> Optional[str]:
"""Returns the path to the caller module's config folder"""
caller_module_name = PKGs.get_caller_module_name()
config_module = caller_module_name.split(".")[0]
config_module_path = f"{config_module}.{config_path}"
if config_module_path == __config_module_path__:
return config_module_path
# check if the config module is importable
try:
__import__(config_module_path)
return config_module_path
except ImportError:
logger.info("Config module not found: %s", config_module_path)
return None


_UNSPECIFIED_: Any = object()


Expand Down Expand Up @@ -86,11 +108,16 @@ def create_config_search_path(
search_path = ConfigSearchPathImpl()
search_path.append("hydra", "pkg://hydra.conf")

if config_module is not None:
if config_module:
search_path.append("main", f"pkg://{config_module}")

if config_module != __config_module__:
search_path.append("hyfi", f"pkg://{__config_module__}")
caller_config_module = get_caller_config_module_path()
if caller_config_module:
search_path.append("caller", f"pkg://{caller_config_module}")
if (
config_module != __config_module_path__
and caller_config_module != __config_module_path__
):
search_path.append("hyfi", f"pkg://{__config_module_path__}")

if search_path_dir is not None and os.path.isdir(search_path_dir):
search_path.append("hyfi", f"file://{search_path_dir}")
Expand Down

0 comments on commit 074c538

Please sign in to comment.