Skip to content

Commit

Permalink
Refactor OmegaConfigLoader (#4100)
Browse files Browse the repository at this point in the history
Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>
  • Loading branch information
merelcht authored Aug 29, 2024
1 parent 080b265 commit f738dc8
Showing 1 changed file with 65 additions and 23 deletions.
88 changes: 65 additions & 23 deletions kedro/config/omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import mimetypes
import typing
from collections.abc import KeysView
from enum import Enum, auto
from pathlib import Path
from typing import Any, Callable, Iterable

Expand All @@ -26,6 +27,17 @@
_NO_VALUE = object()


class MergeStrategies(Enum):
    """Strategies available for merging two configuration dictionaries.

    Members are looked up by name, so the member names are the upper-cased
    form of the strategy names accepted in the merge-strategy setting.
    """

    SOFT = 1
    DESTRUCTIVE = 2


# Name of the loader method that implements each strategy; the method is
# resolved on the loader instance with ``getattr`` at merge time.
MERGING_IMPLEMENTATIONS = {
    MergeStrategies.SOFT: "_soft_merge",
    MergeStrategies.DESTRUCTIVE: "_destructive_merge",
}


class OmegaConfigLoader(AbstractConfigLoader):
"""Recursively scan directories (config paths) contained in ``conf_source`` for
configuration files with a ``yaml``, ``yml`` or ``json`` extension, load and merge
Expand Down Expand Up @@ -131,18 +143,9 @@ def __init__( # noqa: PLR0913
self._register_new_resolvers(custom_resolvers)
# Register globals resolver
self._register_globals_resolver()
file_mimetype, _ = mimetypes.guess_type(conf_source)
if file_mimetype == "application/x-tar":
self._protocol = "tar"
elif file_mimetype in (
"application/zip",
"application/x-zip-compressed",
"application/zip-compressed",
):
self._protocol = "zip"
else:
self._protocol = "file"
self._fs = fsspec.filesystem(protocol=self._protocol, fo=conf_source)

# Setup file system and protocol
self._fs, self._protocol = self._initialise_filesystem_and_protocol(conf_source)

super().__init__(
conf_source=conf_source,
Expand Down Expand Up @@ -220,6 +223,11 @@ def __getitem__(self, key: str) -> dict[str, Any]: # noqa: PLR0912

# Load chosen env config
run_env = self.env or self.default_run_env

# Return if chosen env config is the same as base config to avoid loading the same config twice
if run_env == self.base_env:
return config # type: ignore[no-any-return]

if self._protocol == "file":
env_path = str(Path(self.conf_source) / run_env)
else:
Expand All @@ -236,16 +244,7 @@ def __getitem__(self, key: str) -> dict[str, Any]: # noqa: PLR0912
else:
raise exc

merging_strategy = self.merge_strategy.get(key)
if merging_strategy == "soft":
resulting_config = self._soft_merge(config, env_config)
elif merging_strategy == "destructive" or not merging_strategy:
resulting_config = self._destructive_merge(config, env_config, env_path)
else:
raise ValueError(
f"Merging strategy {merging_strategy} not supported. The accepted merging "
f"strategies are `soft` and `destructive`."
)
resulting_config = self._merge_configs(config, env_config, key, env_path)

if not processed_files and key != "globals":
raise MissingConfigException(
Expand Down Expand Up @@ -355,6 +354,47 @@ def load_and_merge_dir_config(
if not k.startswith("_")
}

@staticmethod
def _initialise_filesystem_and_protocol(
    conf_source: str,
) -> tuple[fsspec.AbstractFileSystem, str]:
    """Create the file system used to read ``conf_source``.

    The protocol is chosen from the MIME type guessed from the
    ``conf_source`` name: tar archives use ``tar``, zip archives use ``zip``
    and anything else (including an unrecognised type) uses ``file``.

    Args:
        conf_source: Path to the configuration directory or archive.

    Returns:
        A tuple of the ``fsspec`` file system and the protocol name.
    """
    # MIME types that identify an archive, mapped to the matching protocol.
    protocol_by_mimetype = {
        "application/x-tar": "tar",
        "application/zip": "zip",
        "application/x-zip-compressed": "zip",
        "application/zip-compressed": "zip",
    }
    guessed_type = mimetypes.guess_type(conf_source)[0]
    protocol = protocol_by_mimetype.get(guessed_type, "file")
    return fsspec.filesystem(protocol=protocol, fo=conf_source), protocol

def _merge_configs(
    self,
    config: dict[str, Any],
    env_config: dict[str, Any],
    key: str,
    env_path: str,
) -> Any:
    """Merge base and environment config with the strategy configured for ``key``.

    Args:
        config: Configuration loaded from the base environment.
        env_config: Configuration loaded from the run environment.
        key: Configuration key whose entry in ``self.merge_strategy``
            selects the merge strategy.
        env_path: Path of the run environment, passed through to the
            merge implementation.

    Returns:
        Whatever the selected merge implementation returns.

    Raises:
        ValueError: If the configured strategy is not the name of a
            ``MergeStrategies`` member (matched case-insensitively).
    """
    # A missing or empty strategy falls back to the destructive merge.
    merging_strategy = self.merge_strategy.get(key) or "destructive"

    # Only the strategy lookup is guarded, so a KeyError raised by the merge
    # implementation itself is not misreported as an unsupported strategy.
    try:
        strategy = MergeStrategies[str(merging_strategy).upper()]
    except KeyError:
        allowed_strategies = [strategy.name.lower() for strategy in MergeStrategies]
        raise ValueError(
            f"Merging strategy {merging_strategy} not supported. The accepted merging "
            f"strategies are {allowed_strategies}."
        ) from None

    # Resolve the implementing method by name and call it.
    merge_function = getattr(self, MERGING_IMPLEMENTATIONS[strategy])
    return merge_function(config, env_config, env_path)

def _get_all_keys(self, cfg: Any, parent_key: str = "") -> set[str]:
keys: set[str] = set()

Expand Down Expand Up @@ -499,7 +539,9 @@ def _destructive_merge(
return config

@staticmethod
def _soft_merge(
    config: dict[str, Any], env_config: dict[str, Any], env_path: str | None = None
) -> Any:
    """Soft merge the base config with the chosen environment's config.

    Values from ``env_config`` override those from ``config`` when keys
    clash. ``env_path`` is accepted but not used by this implementation.
    """
    merged = OmegaConf.merge(config, env_config)
    return OmegaConf.to_container(merged)

Expand Down

0 comments on commit f738dc8

Please sign in to comment.