Skip to content

Commit

Permalink
[CODE] handle_cache
Browse files Browse the repository at this point in the history
  • Loading branch information
clementpoiret committed Mar 15, 2022
1 parent 1e337ce commit a9f370e
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 4 deletions.
51 changes: 51 additions & 0 deletions hsf/_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Handling of cache files."""

import logging
import tempfile
from shutil import rmtree

from rich.logging import RichHandler

FORMAT = "%(message)s"
logging.basicConfig(level="NOTSET",
format=FORMAT,
datefmt="[%X]",
handlers=[RichHandler()])

log = logging.getLogger(__name__)


def handle_cache(func: callable):
"""
Decorator to handle cache files.
Fixes https://github.com/ANTsX/ANTsPy/issues/117
Args:
func (callable): Function to decorate.
Returns:
callable: Decorated function.
"""

def cache(*args, **kwargs):
"""
Wrapper for the function.
Args:
*args: Positional arguments.
**kwargs: Keyword arguments.
Returns:
callable: Function to decorate.
"""
cache_dir = tempfile.mkdtemp()

log.debug(f"Caching in ANTs' files in {cache_dir}")

kwargs["outprefix"] = cache_dir
try:
return func(*args, **kwargs)
finally:
rmtree(cache_dir)

return cache
10 changes: 7 additions & 3 deletions hsf/roiloc_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from rich.logging import RichHandler
from roiloc.locator import RoiLocator

from ._cache import handle_cache

FORMAT = "%(message)s"
logging.basicConfig(level="NOTSET",
format=FORMAT,
Expand Down Expand Up @@ -58,9 +60,11 @@ def get_mri(mri: PosixPath, mask_pattern: Optional[str] = None) -> tuple:
return ants.image_read(str(mri)), mask


@handle_cache
def get_hippocampi(mri: ants.ANTsImage,
roiloc_cfg: DictConfig,
mask: Optional[ants.ANTsImage] = None) -> tuple:
mask: Optional[ants.ANTsImage] = None,
outprefix: str = "") -> tuple:
"""
Locate right and left hippocampi from a given mri.
Expand All @@ -69,15 +73,15 @@ def get_hippocampi(mri: ants.ANTsImage,
roiloc_cfg (DictConfig): Roiloc configuration.
See `github.com/clementpoiret/ROILoc` for more information.
mask (ants.ANTsImage, optional): Loaded mask.
outprefix (str, optional): Prefix for ANTs temp files.
Returns:
RoiLocator: fitted roilocator.
right_mri (ants.ANTsImage): Right hippocampus.
left_mri (ants.ANTsImage): Left hippocampus.
"""
locator = RoiLocator(**roiloc_cfg, mask=mask)

right_mri, left_mri = locator.fit_transform(mri)
right_mri, left_mri = locator.fit_transform(mri, outprefix=outprefix)

return locator, right_mri, left_mri

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ deepsparse_support = 'hsf.engines:print_deepsparse_support'
[tool.poetry.dependencies]
python = "^3.7.1"
torchio = "^0.18.56"
roiloc = "^0.2.7"
roiloc = "^0.2.8"
onnxruntime = "^1.8.1"
hydra-core = "^1.1.1"
wget = "^3.2"
Expand Down

0 comments on commit a9f370e

Please sign in to comment.