diff --git a/hsf/_cache.py b/hsf/_cache.py new file mode 100644 index 0000000..9df77f5 --- /dev/null +++ b/hsf/_cache.py @@ -0,0 +1,51 @@ +"""Handling of cache files.""" + +import logging +import tempfile +from shutil import rmtree + +from rich.logging import RichHandler + +FORMAT = "%(message)s" +logging.basicConfig(level="NOTSET", + format=FORMAT, + datefmt="[%X]", + handlers=[RichHandler()]) + +log = logging.getLogger(__name__) + + +def handle_cache(func: callable): + """ + Decorator to handle cache files. + Fixes https://github.com/ANTsX/ANTsPy/issues/117 + + Args: + func (callable): Function to decorate. + + Returns: + callable: Decorated function. + """ + + def cache(*args, **kwargs): + """ + Wrapper for the function. + + Args: + *args: Positional arguments. + **kwargs: Keyword arguments. + + Returns: + callable: Function to decorate. + """ + cache_dir = tempfile.mkdtemp() + + log.debug(f"Caching in ANTs' files in {cache_dir}") + + kwargs["outprefix"] = cache_dir + try: + return func(*args, **kwargs) + finally: + rmtree(cache_dir) + + return cache diff --git a/hsf/roiloc_wrapper.py b/hsf/roiloc_wrapper.py index 816664d..9140dc0 100644 --- a/hsf/roiloc_wrapper.py +++ b/hsf/roiloc_wrapper.py @@ -7,6 +7,8 @@ from rich.logging import RichHandler from roiloc.locator import RoiLocator +from ._cache import handle_cache + FORMAT = "%(message)s" logging.basicConfig(level="NOTSET", format=FORMAT, @@ -58,9 +60,11 @@ def get_mri(mri: PosixPath, mask_pattern: Optional[str] = None) -> tuple: return ants.image_read(str(mri)), mask +@handle_cache def get_hippocampi(mri: ants.ANTsImage, roiloc_cfg: DictConfig, - mask: Optional[ants.ANTsImage] = None) -> tuple: + mask: Optional[ants.ANTsImage] = None, + outprefix: str = "") -> tuple: """ Locate right and left hippocampi from a given mri. @@ -69,6 +73,7 @@ def get_hippocampi(mri: ants.ANTsImage, roiloc_cfg (DictConfig): Roiloc configuration. See `github.com/clementpoiret/ROILoc` for more information. mask (ants.ANTsImage, optional): Loaded mask. + outprefix (str, optional): Prefix for ANTs temp files. Returns: RoiLocator: fitted roilocator. @@ -76,8 +81,7 @@ def get_hippocampi(mri: ants.ANTsImage, left_mri (ants.ANTsImage): Left hippocampus. """ locator = RoiLocator(**roiloc_cfg, mask=mask) - - right_mri, left_mri = locator.fit_transform(mri) + right_mri, left_mri = locator.fit_transform(mri, outprefix=outprefix) return locator, right_mri, left_mri diff --git a/pyproject.toml b/pyproject.toml index 9bef26a..9763f0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ deepsparse_support = 'hsf.engines:print_deepsparse_support' [tool.poetry.dependencies] python = "^3.7.1" torchio = "^0.18.56" -roiloc = "^0.2.7" +roiloc = "^0.2.8" onnxruntime = "^1.8.1" hydra-core = "^1.1.1" wget = "^3.2"