Skip to content

Commit

Permalink
docs(composer): clarify comments and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
entelecheia committed Jun 27, 2023
1 parent 3099e84 commit cd92334
Showing 1 changed file with 95 additions and 11 deletions.
106 changes: 95 additions & 11 deletions src/hyfi/composer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


class SpecialKeys(str, Enum):
"""Special keys in configs used by hyfi."""
"""Special keys in configs used by HyFI."""

CALL = "_call_"
CONFIG = "_config_"
Expand Down Expand Up @@ -131,10 +131,16 @@ def compose(

@property
def config(self) -> DictConfig:
"""
Returns the composed configuration.
"""
return self.__cfg__

@property
def config_as_dict(self) -> Dict:
"""
Return the configuration as a dictionary.
"""
return Composer.to_dict(self.__cfg__)

def __call__(
Expand Down Expand Up @@ -425,6 +431,17 @@ def _compose(

@staticmethod
def print(cfg: Any, resolve: bool = True, **kwargs):
"""
Prints the configuration object in a human-readable format.
Args:
cfg (Any): The configuration object to print.
resolve (bool, optional): Whether to resolve the configuration object before printing. Defaults to True.
**kwargs: Additional keyword arguments to pass to the pprint.pprint function.
Returns:
None
"""
import pprint

if Composer.is_config(cfg):
Expand All @@ -436,27 +453,67 @@ def print(cfg: Any, resolve: bool = True, **kwargs):
print(cfg)

@staticmethod
def is_config(
cfg: Any,
):
def is_config(cfg: Any):
"""
Determines whether the input object is a valid configuration object.
Args:
cfg (Any): The object to check.
Returns:
bool: True if the object is a valid configuration object, False otherwise.
"""
return isinstance(cfg, (DictConfig, dict))

@staticmethod
def is_list(
cfg: Any,
):
def is_list(cfg: Any):
"""
Determines whether the input object is a valid list configuration object.
Args:
cfg (Any): The object to check.
Returns:
bool: True if the object is a valid list configuration object, False otherwise.
"""
return isinstance(cfg, (ListConfig, list))

@staticmethod
def is_instantiatable(cfg: Any):
"""
Determines whether the input configuration object is instantiatable.
Args:
cfg (Any): The configuration object to check.
Returns:
bool: True if the configuration object is instantiatable, False otherwise.
"""
return Composer.is_config(cfg) and SpecialKeys.TARGET in cfg

@staticmethod
def load(file_: Union[str, Path, IO[Any]]) -> Union[DictConfig, ListConfig]:
"""
Load a configuration file and return a configuration object.
Args:
file_ (Union[str, Path, IO[Any]]): The path to the configuration file or a file-like object.
Returns:
Union[DictConfig, ListConfig]: The configuration object.
"""
return OmegaConf.load(file_)

@staticmethod
def save(config: Any, f: Union[str, Path, IO[Any]], resolve: bool = False) -> None:
"""
Save a configuration object to a file.
Args:
config (Any): The configuration object to save.
f (Union[str, Path, IO[Any]]): The path to the file or a file-like object.
resolve (bool, optional): Whether to resolve the configuration object before saving. Defaults to False.
"""
os.makedirs(os.path.dirname(str(f)), exist_ok=True)
OmegaConf.save(config, f, resolve=resolve)

Expand All @@ -470,6 +527,18 @@ def save_json(
encoding="utf-8",
**kwargs,
):
"""
Save a dictionary to a JSON file.
Args:
json_dict (dict): The dictionary to save.
f (Union[str, Path, IO[Any]]): The path to the file or a file-like object.
indent (int, optional): The number of spaces to use for indentation. Defaults to 4.
ensure_ascii (bool, optional): Whether to escape non-ASCII characters. Defaults to False.
default (Any, optional): A function to convert non-serializable objects. Defaults to None.
encoding (str, optional): The encoding to use. Defaults to "utf-8".
**kwargs: Additional arguments to pass to json.dump().
"""
f = str(f)
os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, "w", encoding=encoding) as f:
Expand All @@ -484,17 +553,32 @@ def save_json(

@staticmethod
def load_json(f: Union[str, Path, IO[Any]], encoding="utf-8", **kwargs) -> dict:
"""
Load a JSON file into a dictionary.
Args:
f (Union[str, Path, IO[Any]]): The path to the file or a file-like object.
encoding (str, optional): The encoding to use. Defaults to "utf-8".
**kwargs: Additional arguments to pass to json.load().
Returns:
dict: The dictionary loaded from the JSON file.
"""
f = str(f)
with open(f, "r", encoding=encoding) as f:
return json.load(f, **kwargs)

@staticmethod
def update(_dict: Mapping[str, Any], _overrides: Mapping[str, Any]) -> Mapping:
"""
Update a dictionary with overrides
:param _dict: dictionary to update
:param _overrides: dictionary with overrides
:return: updated dictionary
Update a dictionary with overrides.
Args:
_dict (Mapping[str, Any]): The dictionary to update.
_overrides (Mapping[str, Any]): The dictionary with overrides.
Returns:
Mapping: The updated dictionary.
"""
for k, v in _overrides.items():
if isinstance(v, collections.abc.Mapping):
Expand Down

0 comments on commit cd92334

Please sign in to comment.