diff --git a/src/hyfi/composer/composer.py b/src/hyfi/composer/composer.py index 5cd0d403..555a3f2f 100644 --- a/src/hyfi/composer/composer.py +++ b/src/hyfi/composer/composer.py @@ -5,7 +5,7 @@ import os import re from enum import Enum -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Union import hydra from hydra.core.global_hydra import GlobalHydra @@ -43,6 +43,94 @@ class SpecialKeys(str, Enum): WITH = "_with_" +class ConfigGroup(BaseModel): + config_group: str = "" + _group_override_: str = "" + _group_key_: str = "" + _group_value_: str = "" + _global_package_: bool = False + + @property + def group_override(self): + if not self._group_override_: + self._split_config_group() + return self._group_override_ + + @property + def group_key(self): + if not self._group_key_: + self._split_config_group() + return self._group_key_ + + @property + def group_value(self): + if not self._group_value_: + self._split_config_group() + return self._group_value_ + + @property + def global_package(self): + return self._global_package_ + + def get_overrides( + self, + overrides: Optional[List[str]] = None, + root_config_name: Optional[str] = None, + config_module: Optional[str] = None, + ) -> List[str]: + if overrides is None: + overrides = [] + if self.group_key in global_hyfi.global_package_list: + self._global_package_ = True + # If group_key and group_value are specified in the configuration file. + if self.group_key and self.group_value: + config = Composer.hydra_compose( + root_config_name=root_config_name, + config_module=config_module, + overrides=overrides, + ) + config = Composer.select( + config, + key=self.group_key, + default=None, + throw_on_missing=False, + throw_on_resolution_failure=False, + ) + override = ( + self.group_override if config is not None else f"+{self.group_override}" + ) + # Add override to overrides list. + if override: + if overrides: + overrides.append(override) + else: + overrides = [override] + + return overrides + + def _split_config_group(self): + config_group = self.config_group + if config_group: + group_ = config_group.split("=") + # group_key group_value group_key group_value group_key group_value default + if len(group_) == 2: + group_key, group_value = group_ + else: + group_key = group_[0] + group_value = global_hyfi.hydra_default_config_group_value + # remove leading slash + if group_key.startswith("/"): + group_key = group_key[1:] + config_group = f"{group_key}={group_value}" + else: + group_key = "" + group_value = "" + config_group = "" + self._group_override_ = config_group + self._group_key_ = group_key + self._group_value_ = group_value + + class Composer(BaseModel, CONFs): """ Compose a configuration by applying overrides @@ -202,10 +290,10 @@ def hydra_compose( is_initialized = GlobalHydra.instance().is_initialized() # type: ignore config_module = config_module or global_hyfi.config_module plugins = plugins or global_hyfi.plugins - logger.debug("config_module: %s", config_module) + # logger.debug("config_module: %s", config_module) if is_initialized: # Hydra is already initialized. - logger.debug("Hydra is already initialized") + # logger.debug("Hydra is already initialized") cfg = hydra.compose(config_name=root_config_name, overrides=overrides) else: with hyfi_hydra.initialize_config( @@ -217,25 +305,6 @@ def hydra_compose( cfg = hydra.compose(config_name=root_config_name, overrides=overrides) return cfg - @staticmethod - def split_config_group( - config_group: Optional[str] = None, - ) -> Tuple[str, str, str]: - if config_group: - group_ = config_group.split("=") - # group_key group_value group_key group_value group_key group_value default - if len(group_) == 2: - group_key, group_value = group_ - else: - group_key = group_[0] - group_value = global_hyfi.hydra_default_config_group_value - config_group = f"{group_key}={group_value}" - else: - group_key = "" - group_value = "" - config_group = "" - return config_group, group_key, group_value - @staticmethod def _compose_as_dict( config_group: Optional[str] = None, @@ -310,35 +379,12 @@ def _compose_internal( if isinstance(config_data, DictConfig): logger.debug("returning config_data without composing") return config_data - # Set overrides to the empty list if None - if overrides is None: - overrides = [] - # Set the group key and value of the config group. - config_group, group_key, group_value = Composer.split_config_group(config_group) - if group_key in global_hyfi.global_package_list: - global_package = True - # If group_key and group_value are specified in the configuration file. - if group_key and group_value: - # Initialize hydra configuration module. - cfg = Composer.hydra_compose( - root_config_name=root_config_name, - config_module=config_module, - overrides=overrides, - ) - cfg = Composer.select( - cfg, - key=group_key, - default=None, - throw_on_missing=False, - throw_on_resolution_failure=False, - ) - override = config_group if cfg is not None else f"+{config_group}" - # Add override to overrides list. - if isinstance(override, str): - if overrides: - overrides.append(override) - else: - overrides = [override] + cg = ConfigGroup(config_group=config_group) + overrides = cg.get_overrides( + overrides=overrides, + root_config_name=root_config_name, + config_module=config_module, + ) logger.debug(f"compose config with overrides: {overrides}") # Initialize hydra and return the configuration. @@ -348,18 +394,19 @@ def _compose_internal( overrides=overrides, ) # Add config group overrides to overrides list. - if group_key and not global_package: + global_package = global_package or cg.global_package + if cg.group_key and not global_package: group_overrides: List[str] = [] group_cfg = Composer.select( cfg, - key=group_key, + key=cg.group_key, default=None, throw_on_missing=False, throw_on_resolution_failure=False, ) if config_data and group_cfg: group_overrides.extend( - f"{group_key}.{k}={v}" + f"{cg.group_key}.{k}={v}" for k, v in config_data.items() if isinstance(v, (str, int, float, bool)) and k in group_cfg ) @@ -374,7 +421,7 @@ def _compose_internal( # Select the group_key from the configuration. cfg = Composer.select( cfg, - key=group_key, + key=cg.group_key, default=None, throw_on_missing=throw_on_missing, throw_on_resolution_failure=throw_on_resolution_failure,