Skip to content

Commit

Permalink
Merge pull request #298 from juaml/feat/get-template
Browse files Browse the repository at this point in the history
[ENH]: Introduce `get_template` for getting templates
  • Loading branch information
synchon authored Feb 9, 2024
2 parents 4397f3f + 1d5d635 commit 03d5e91
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/298.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Introduce :func:`.get_template` to fetch template space image tailored to a target data by `Synchon Mandal`_
2 changes: 2 additions & 0 deletions junifer/api/tests/test_api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def test_get_dependency_information_short() -> None:
"ruamel.yaml",
"httpx",
"tqdm",
"templateflow",
]
if int(pl.python_version_tuple()[1]) < 10:
dependency_list.append("importlib_metadata")
Expand All @@ -68,6 +69,7 @@ def test_get_dependency_information_long() -> None:
"ruamel.yaml",
"httpx",
"tqdm",
"templateflow",
]
for key in dependency_list:
assert key in dependency_information_keys
Expand Down
2 changes: 1 addition & 1 deletion junifer/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@
get_mask,
)

from .template_spaces import get_xfm
from .template_spaces import get_template, get_xfm

from . import utils
74 changes: 73 additions & 1 deletion junifer/data/template_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
# License: AGPL

from pathlib import Path
from typing import Union
from typing import Any, Dict, Optional, Union

import httpx
import nibabel as nib
import numpy as np
from templateflow import api as tflow

from ..utils import logger, raise_error
from .utils import closest_resolution


def get_xfm(
Expand Down Expand Up @@ -89,3 +93,71 @@ def get_xfm(
f.write(chunk)

return xfm_file_path


def get_template(
space: str,
target_data: Dict[str, Any],
extra_input: Optional[Dict[str, Any]] = None,
) -> nib.Nifti1Image:
"""Get template for the space, tailored for the target image.
Parameters
----------
space : str
The name of the template space.
target_data : dict
The corresponding item of the data object for which the template space
will be loaded.
extra_input : dict, optional
The other fields in the data object. Useful for accessing other data
types (default None).
Returns
-------
Nifti1Image
The template image.
Raises
------
ValueError
If ``space`` is invalid.
RuntimeError
If template in the required resolution is not found.
"""
# Check for invalid space; early check to raise proper error
if space not in tflow.templates():
raise_error(f"Unknown template space: {space}")

# Get the min of the voxels sizes and use it as the resolution
target_img = target_data["data"]
resolution = np.min(target_img.header.get_zooms()[:3]).astype(int)

# Fetch available resolutions for the template
available_resolutions = [
int(min(val["zooms"]))
for val in tflow.get_metadata(space)["res"].values()
]
# Use the closest resolution if desired resolution is not found
resolution = closest_resolution(resolution, available_resolutions)

logger.info(f"Downloading template {space} in resolution {resolution}")
# Retrieve template
try:
template_path = tflow.get(
space,
raise_empty=True,
resolution=resolution,
suffix="T1w",
desc=None,
extension="nii.gz",
)
except Exception: # noqa: BLE001
raise_error(
f"Template {space} not found in the required resolution "
f"{resolution}",
klass=RuntimeError,
)
else:
return nib.load(template_path) # type: ignore
41 changes: 40 additions & 1 deletion junifer/data/tests/test_template_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import socket
from pathlib import Path

import nibabel as nib
import pytest

from junifer.data import get_xfm
from junifer.data import get_template, get_xfm
from junifer.datareader import DefaultDataReader
from junifer.testing.datagrabbers import OasisVBMTestingDataGrabber


@pytest.mark.skipif(
Expand All @@ -28,3 +31,39 @@ def test_get_xfm(tmp_path: Path) -> None:
src="MNI152NLin6Asym", dst="MNI152NLin2009cAsym", xfms_dir=tmp_path
)
assert isinstance(xfm_path, Path)


def test_get_template() -> None:
"""Test tailored template image fetch."""
with OasisVBMTestingDataGrabber() as dg:
element = dg["sub-01"]
element_data = DefaultDataReader().fit_transform(element)
vbm_gm = element_data["VBM_GM"]
# Get tailored parcellation
tailored_template = get_template(
space=vbm_gm["space"], target_data=vbm_gm
)
assert isinstance(tailored_template, nib.Nifti1Image)


def test_get_template_invalid_space() -> None:
"""Test invalid space check for template fetch."""
with OasisVBMTestingDataGrabber() as dg:
element = dg["sub-01"]
element_data = DefaultDataReader().fit_transform(element)
vbm_gm = element_data["VBM_GM"]
# Get tailored parcellation
with pytest.raises(ValueError, match="Unknown template space:"):
_ = get_template(space="andromeda", target_data=vbm_gm)


def test_get_template_closest_resolution() -> None:
"""Test closest resolution check for template fetch."""
with OasisVBMTestingDataGrabber() as dg:
element = dg["sub-01"]
element_data = DefaultDataReader().fit_transform(element)
vbm_gm = element_data["VBM_GM"]
# Change header resolution to fetch closest resolution
element_data["VBM_GM"]["data"].header.set_zooms((3, 3, 3))
template = get_template(space=vbm_gm["space"], target_data=vbm_gm)
assert isinstance(template, nib.Nifti1Image)
4 changes: 2 additions & 2 deletions junifer/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@


def closest_resolution(
resolution: Optional[float],
resolution: Optional[Union[float, int]],
valid_resolution: Union[List[float], List[int], np.ndarray],
) -> Union[float, int]:
"""Find the closest resolution.
Parameters
----------
resolution : float, optional
resolution : float or int, optional
The given resolution. If None, will return the highest resolution
(default None).
valid_resolution : list of float or int, or np.ndarray
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dependencies = [
"h5py>=3.8.0,<3.10",
"httpx[http2]==0.26.0",
"tqdm==4.66.1",
"templateflow>=23.0.0",
]
dynamic = ["version"]

Expand Down Expand Up @@ -188,8 +189,10 @@ known-third-party =[
"nilearn",
"sqlalchemy",
"yaml",
"importlib_metadata",
"httpx",
"tqdm",
"templateflow",
"bct",
"neurokit2",
"pytest",
Expand Down

0 comments on commit 03d5e91

Please sign in to comment.