From cd9233444d6bfb21821acc966e946261cc02a052 Mon Sep 17 00:00:00 2001 From: Young Joon Lee Date: Tue, 27 Jun 2023 12:53:40 +0900 Subject: [PATCH] docs(composer): clarify comments and docstrings --- src/hyfi/composer/__init__.py | 106 ++++++++++++++++++++++++++++++---- 1 file changed, 95 insertions(+), 11 deletions(-) diff --git a/src/hyfi/composer/__init__.py b/src/hyfi/composer/__init__.py index 214f322b..f6fa72af 100644 --- a/src/hyfi/composer/__init__.py +++ b/src/hyfi/composer/__init__.py @@ -29,7 +29,7 @@ class SpecialKeys(str, Enum): - """Special keys in configs used by hyfi.""" + """Special keys in configs used by HyFI.""" CALL = "_call_" CONFIG = "_config_" @@ -131,10 +131,16 @@ def compose( @property def config(self) -> DictConfig: + """ + Returns the composed configuration. + """ return self.__cfg__ @property def config_as_dict(self) -> Dict: + """ + Return the configuration as a dictionary. + """ return Composer.to_dict(self.__cfg__) def __call__( @@ -425,6 +431,17 @@ def _compose( @staticmethod def print(cfg: Any, resolve: bool = True, **kwargs): + """ + Prints the configuration object in a human-readable format. + + Args: + cfg (Any): The configuration object to print. + resolve (bool, optional): Whether to resolve the configuration object before printing. Defaults to True. + **kwargs: Additional keyword arguments to pass to the pprint.pprint function. + + Returns: + None + """ import pprint if Composer.is_config(cfg): @@ -436,27 +453,67 @@ def print(cfg: Any, resolve: bool = True, **kwargs): print(cfg) @staticmethod - def is_config( - cfg: Any, - ): + def is_config(cfg: Any): + """ + Determines whether the input object is a valid configuration object. + + Args: + cfg (Any): The object to check. + + Returns: + bool: True if the object is a valid configuration object, False otherwise. + """ return isinstance(cfg, (DictConfig, dict)) @staticmethod - def is_list( - cfg: Any, - ): + def is_list(cfg: Any): + """ + Determines whether the input object is a valid list configuration object. + + Args: + cfg (Any): The object to check. + + Returns: + bool: True if the object is a valid list configuration object, False otherwise. + """ return isinstance(cfg, (ListConfig, list)) @staticmethod def is_instantiatable(cfg: Any): + """ + Determines whether the input configuration object is instantiatable. + + Args: + cfg (Any): The configuration object to check. + + Returns: + bool: True if the configuration object is instantiatable, False otherwise. + """ return Composer.is_config(cfg) and SpecialKeys.TARGET in cfg @staticmethod def load(file_: Union[str, Path, IO[Any]]) -> Union[DictConfig, ListConfig]: + """ + Load a configuration file and return a configuration object. + + Args: + file_ (Union[str, Path, IO[Any]]): The path to the configuration file or a file-like object. + + Returns: + Union[DictConfig, ListConfig]: The configuration object. + """ return OmegaConf.load(file_) @staticmethod def save(config: Any, f: Union[str, Path, IO[Any]], resolve: bool = False) -> None: + """ + Save a configuration object to a file. + + Args: + config (Any): The configuration object to save. + f (Union[str, Path, IO[Any]]): The path to the file or a file-like object. + resolve (bool, optional): Whether to resolve the configuration object before saving. Defaults to False. + """ os.makedirs(os.path.dirname(str(f)), exist_ok=True) OmegaConf.save(config, f, resolve=resolve) @@ -470,6 +527,18 @@ def save_json( encoding="utf-8", **kwargs, ): + """ + Save a dictionary to a JSON file. + + Args: + json_dict (dict): The dictionary to save. + f (Union[str, Path, IO[Any]]): The path to the file or a file-like object. + indent (int, optional): The number of spaces to use for indentation. Defaults to 4. + ensure_ascii (bool, optional): Whether to escape non-ASCII characters. Defaults to False. + default (Any, optional): A function to convert non-serializable objects. Defaults to None. + encoding (str, optional): The encoding to use. Defaults to "utf-8". + **kwargs: Additional arguments to pass to json.dump(). + """ f = str(f) os.makedirs(os.path.dirname(f), exist_ok=True) with open(f, "w", encoding=encoding) as f: @@ -484,6 +553,17 @@ def save_json( @staticmethod def load_json(f: Union[str, Path, IO[Any]], encoding="utf-8", **kwargs) -> dict: + """ + Load a JSON file into a dictionary. + + Args: + f (Union[str, Path, IO[Any]]): The path to the file or a file-like object. + encoding (str, optional): The encoding to use. Defaults to "utf-8". + **kwargs: Additional arguments to pass to json.load(). + + Returns: + dict: The dictionary loaded from the JSON file. + """ f = str(f) with open(f, "r", encoding=encoding) as f: return json.load(f, **kwargs) @@ -491,10 +571,14 @@ def load_json(f: Union[str, Path, IO[Any]], encoding="utf-8", **kwargs) -> dict: @staticmethod def update(_dict: Mapping[str, Any], _overrides: Mapping[str, Any]) -> Mapping: """ - Update a dictionary with overrides - :param _dict: dictionary to update - :param _overrides: dictionary with overrides - :return: updated dictionary + Update a dictionary with overrides. + + Args: + _dict (Mapping[str, Any]): The dictionary to update. + _overrides (Mapping[str, Any]): The dictionary with overrides. + + Returns: + Mapping: The updated dictionary. """ for k, v in _overrides.items(): if isinstance(v, collections.abc.Mapping):