Skip to content

Commit

Permalink
feat(composer): add global package list support, add instantiate_conf…
Browse files Browse the repository at this point in the history
…ig method, add print_config method

refactor(composer): modify conditions for group overrides
  • Loading branch information
entelecheia committed Jul 21, 2023
1 parent 109b189 commit f98a679
Showing 1 changed file with 87 additions and 23 deletions.
110 changes: 87 additions & 23 deletions src/hyfi/composer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
LOGGING.setLogger(level)
logger = LOGGING.getLogger(__name__)

__global_package_list__: Set[str] = {"cmd", "mode", "workflow"}


class SpecialKeys(str, Enum):
"""Special keys in configs used by HyFI."""
Expand Down Expand Up @@ -280,6 +282,8 @@ def _compose(
overrides = []
# Set the group key and value of the config group.
config_group, group_key, group_value = Composer.split_config_group(config_group)
if group_key in __global_package_list__:
global_package = True
# If group_key and group_value are specified in the configuration file.
if group_key and group_value:
# Initialize hydra configuration module.
Expand Down Expand Up @@ -311,30 +315,30 @@ def _compose(
overrides=overrides,
)
# Add config group overrides to overrides list.
group_overrides: List[str] = []
group_cfg = Composer.select(
cfg,
key=group_key,
default=None,
throw_on_missing=False,
throw_on_resolution_failure=False,
)
if config_data and group_cfg:
group_overrides.extend(
f"{group_key}.{k}={v}"
for k, v in config_data.items()
if isinstance(v, (str, int, float, bool)) and k in group_cfg
)
if group_overrides:
overrides.extend(group_overrides)
cfg = Composer.hydra_compose(
root_config_name=root_config_name,
config_module=config_module,
overrides=overrides,
)

# Select the group_key from the configuration.
if group_key and not global_package:
group_overrides: List[str] = []
group_cfg = Composer.select(
cfg,
key=group_key,
default=None,
throw_on_missing=False,
throw_on_resolution_failure=False,
)
if config_data and group_cfg:
group_overrides.extend(
f"{group_key}.{k}={v}"
for k, v in config_data.items()
if isinstance(v, (str, int, float, bool)) and k in group_cfg
)
if group_overrides:
overrides.extend(group_overrides)
cfg = Composer.hydra_compose(
root_config_name=root_config_name,
config_module=config_module,
overrides=overrides,
)

# Select the group_key from the configuration.
cfg = Composer.select(
cfg,
key=group_key,
Expand Down Expand Up @@ -519,6 +523,66 @@ def viewsource(obj: Any):
"""
print(Composer.getsource(obj))

@staticmethod
def instantiate_config(
config_group: Union[str, None] = None,
overrides: Union[List[str], None] = None,
config_data: Union[Dict[str, Any], DictConfig, None] = None,
global_package: bool = False,
*args: Any,
**kwargs: Any,
) -> Any:
"""
Instantiates an object using the provided config group and overrides.
Args:
config_group: Name of the config group to compose (`config_group=name`)
overrides: List of config groups to apply overrides to (`overrides=["override_name"]`)
config_data: Keyword arguments to override config group values (will be converted to overrides of the form `config_group_name.key=value`)
global_package: If True, the config assumed to be a global package
args: Optional positional parameters pass-through
kwargs: Optional named parameters to override
parameters in the config object. Parameters not present
in the config objects are being passed as is to the target.
IMPORTANT: dataclasses instances in kwargs are interpreted as config
and cannot be used as passthrough
Returns:
if _target_ is a class name: the instantiated object
if _target_ is a callable: the return value of the call
"""
cfg = Composer._compose(
config_group=config_group,
overrides=overrides,
config_data=config_data,
global_package=global_package,
)
return Composer.instantiate(cfg, *args, **kwargs)

@staticmethod
def print_config(
config_group: Union[str, None] = None,
overrides: Union[List[str], None] = None,
config_data: Union[Dict[str, Any], DictConfig, None] = None,
global_package: bool = False,
):
"""
Print the configuration
Args:
config_group: Name of the config group to compose (`config_group=name`)
overrides: List of config groups to apply overrides to (`overrides=["override_name"]`)
config_data: Keyword arguments to override config group values (will be converted to overrides of the form `config_group_name.key=value`)
global_package: If True, the config assumed to be a global package
"""
cfg = Composer._compose(
config_group=config_group,
overrides=overrides,
config_data=config_data,
global_package=global_package,
)
Composer.print(cfg)


class BaseConfig(BaseModel):
"""
Expand Down

0 comments on commit f98a679

Please sign in to comment.