Skip to content

Commit

Permalink
fix(composer): decouple ConfigGroup logic from Composer
Browse files Browse the repository at this point in the history
  • Loading branch information
entelecheia committed Aug 2, 2023
1 parent 542445b commit 63144ff
Showing 1 changed file with 102 additions and 55 deletions.
157 changes: 102 additions & 55 deletions src/hyfi/composer/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import re
from enum import Enum
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Union

import hydra
from hydra.core.global_hydra import GlobalHydra
Expand Down Expand Up @@ -43,6 +43,94 @@ class SpecialKeys(str, Enum):
WITH = "_with_"


class ConfigGroup(BaseModel):
config_group: str = ""
_group_override_: str = ""
_group_key_: str = ""
_group_value_: str = ""
_global_package_: bool = False

@property
def group_override(self):
if not self._group_override_:
self._split_config_group()
return self._group_override_

@property
def group_key(self):
if not self._group_key_:
self._split_config_group()
return self._group_key_

@property
def group_value(self):
if not self._group_value_:
self._split_config_group()
return self._group_value_

@property
def global_package(self):
return self._global_package_

def get_overrides(
self,
overrides: Optional[List[str]] = None,
root_config_name: Optional[str] = None,
config_module: Optional[str] = None,
) -> List[str]:
if overrides is None:
overrides = []
if self.group_key in global_hyfi.global_package_list:
self._global_package_ = True
# If group_key and group_value are specified in the configuration file.
if self.group_key and self.group_value:
config = Composer.hydra_compose(
root_config_name=root_config_name,
config_module=config_module,
overrides=overrides,
)
config = Composer.select(
config,
key=self.group_key,
default=None,
throw_on_missing=False,
throw_on_resolution_failure=False,
)
override = (
self.group_override if config is not None else f"+{self.group_override}"
)
# Add override to overrides list.
if override:
if overrides:
overrides.append(override)
else:
overrides = [override]

return overrides

def _split_config_group(self):
config_group = self.config_group
if config_group:
group_ = config_group.split("=")
# group_key group_value group_key group_value group_key group_value default
if len(group_) == 2:
group_key, group_value = group_
else:
group_key = group_[0]
group_value = global_hyfi.hydra_default_config_group_value
# remove leading slash
if group_key.startswith("/"):
group_key = group_key[1:]
config_group = f"{group_key}={group_value}"
else:
group_key = ""
group_value = ""
config_group = ""
self._group_override_ = config_group
self._group_key_ = group_key
self._group_value_ = group_value


class Composer(BaseModel, CONFs):
"""
Compose a configuration by applying overrides
Expand Down Expand Up @@ -202,10 +290,10 @@ def hydra_compose(
is_initialized = GlobalHydra.instance().is_initialized() # type: ignore
config_module = config_module or global_hyfi.config_module
plugins = plugins or global_hyfi.plugins
logger.debug("config_module: %s", config_module)
# logger.debug("config_module: %s", config_module)
if is_initialized:
# Hydra is already initialized.
logger.debug("Hydra is already initialized")
# logger.debug("Hydra is already initialized")
cfg = hydra.compose(config_name=root_config_name, overrides=overrides)
else:
with hyfi_hydra.initialize_config(
Expand All @@ -217,25 +305,6 @@ def hydra_compose(
cfg = hydra.compose(config_name=root_config_name, overrides=overrides)
return cfg

@staticmethod
def split_config_group(
config_group: Optional[str] = None,
) -> Tuple[str, str, str]:
if config_group:
group_ = config_group.split("=")
# group_key group_value group_key group_value group_key group_value default
if len(group_) == 2:
group_key, group_value = group_
else:
group_key = group_[0]
group_value = global_hyfi.hydra_default_config_group_value
config_group = f"{group_key}={group_value}"
else:
group_key = ""
group_value = ""
config_group = ""
return config_group, group_key, group_value

@staticmethod
def _compose_as_dict(
config_group: Optional[str] = None,
Expand Down Expand Up @@ -310,35 +379,12 @@ def _compose_internal(
if isinstance(config_data, DictConfig):
logger.debug("returning config_data without composing")
return config_data
# Set overrides to the empty list if None
if overrides is None:
overrides = []
# Set the group key and value of the config group.
config_group, group_key, group_value = Composer.split_config_group(config_group)
if group_key in global_hyfi.global_package_list:
global_package = True
# If group_key and group_value are specified in the configuration file.
if group_key and group_value:
# Initialize hydra configuration module.
cfg = Composer.hydra_compose(
root_config_name=root_config_name,
config_module=config_module,
overrides=overrides,
)
cfg = Composer.select(
cfg,
key=group_key,
default=None,
throw_on_missing=False,
throw_on_resolution_failure=False,
)
override = config_group if cfg is not None else f"+{config_group}"
# Add override to overrides list.
if isinstance(override, str):
if overrides:
overrides.append(override)
else:
overrides = [override]
cg = ConfigGroup(config_group=config_group)
overrides = cg.get_overrides(
overrides=overrides,
root_config_name=root_config_name,
config_module=config_module,
)

logger.debug(f"compose config with overrides: {overrides}")
# Initialize hydra and return the configuration.
Expand All @@ -348,18 +394,19 @@ def _compose_internal(
overrides=overrides,
)
# Add config group overrides to overrides list.
if group_key and not global_package:
global_package = global_package or cg.global_package
if cg.group_key and not global_package:
group_overrides: List[str] = []
group_cfg = Composer.select(
cfg,
key=group_key,
key=cg.group_key,
default=None,
throw_on_missing=False,
throw_on_resolution_failure=False,
)
if config_data and group_cfg:
group_overrides.extend(
f"{group_key}.{k}={v}"
f"{cg.group_key}.{k}={v}"
for k, v in config_data.items()
if isinstance(v, (str, int, float, bool)) and k in group_cfg
)
Expand All @@ -374,7 +421,7 @@ def _compose_internal(
# Select the group_key from the configuration.
cfg = Composer.select(
cfg,
key=group_key,
key=cg.group_key,
default=None,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
Expand Down

0 comments on commit 63144ff

Please sign in to comment.