Skip to content

Commit

Permalink
Merge pull request #414 from juaml/fix/non-native-data-warping
Browse files Browse the repository at this point in the history
[ENH]: Simplify space warping for parcellations and masks
  • Loading branch information
synchon authored Dec 9, 2024
2 parents 12ecdb2 + c93d589 commit 69275b2
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 45 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/414.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Simplify logic for parcellation and mask space warping by `Synchon Mandal`_
45 changes: 26 additions & 19 deletions junifer/data/masks/_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
)

import nibabel as nib
import nilearn.image as nimg
import numpy as np
from nilearn.image import get_data, new_img_like, resample_to_img
from nilearn.masking import (
compute_background_mask,
compute_epi_mask,
Expand Down Expand Up @@ -168,14 +168,14 @@ def compute_brain_mask(
)
# Resample template to target image
else:
resampled_template = resample_to_img(
resampled_template = nimg.resample_to_img(
source_img=template, target_img=target_data["data"]
)

# Threshold resampled template and get mask
mask = (get_data(resampled_template) >= threshold).astype("int8")
mask = (nimg.get_data(resampled_template) >= threshold).astype("int8")

return new_img_like(target_data["data"], mask) # type: ignore
return nimg.new_img_like(target_data["data"], mask) # type: ignore


class MaskRegistry(BasePipelineDataRegistry, metaclass=Singleton):
Expand Down Expand Up @@ -523,7 +523,7 @@ def get( # noqa: C901
)
logger.debug("Resampling inherited mask to target image.")
# Resample inherited mask to target image
mask_img = resample_to_img(
mask_img = nimg.resample_to_img(
source_img=mask_img,
target_img=target_data["data"],
)
Expand Down Expand Up @@ -568,34 +568,41 @@ def get( # noqa: C901
"mask."
)

# Set here to simplify things later
mask_img: nib.nifti1.Nifti1Image = mask_object

# Resample and warp mask to standard space
if mask_space != target_std_space:
logger.debug(
f"Warping {t_mask} to {target_std_space} space "
"using ants."
"using ANTs."
)
mask_img = ANTsMaskWarper().warp(
mask_name=mask_name,
mask_img=mask_object,
mask_img=mask_img,
src=mask_space,
dst=target_std_space,
target_data=target_data,
warp_data=warper_spec,
)
# Remove extra dimension added by ANTs
mask_img = nimg.math_img(
"np.squeeze(img)", img=mask_img
)

else:
# Resample mask to target image; no further warping
if target_space != "native":
# No warping is going to happen, just resampling,
# because we are in the correct space
logger.debug(f"Resampling {t_mask} to target image.")
if target_space != "native":
mask_img = resample_to_img(
source_img=mask_object,
target_img=target_data["data"],
)
# Set mask_img in case no warping happens before this
else:
mask_img = mask_object
# Resample and warp mask if target data is native
if target_space == "native":
mask_img = nimg.resample_to_img(
source_img=mask_img,
target_img=target_img,
)
else:
# Warp mask if target space is native as
# either the image is in the right non-native space or
# it's warped from one non-native space to another
# non-native space
logger.debug(
"Warping mask to native space using "
f"{warper_spec['warper']}."
Expand Down
40 changes: 18 additions & 22 deletions junifer/data/parcellations/_parcellations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

import httpx
import nibabel as nib
import nilearn.image as nimg
import numpy as np
import pandas as pd
from nilearn import datasets, image
from nilearn import datasets

from ...utils import logger, raise_error, warn_with_log
from ...utils.singleton import Singleton
Expand Down Expand Up @@ -470,28 +471,23 @@ def get(
warp_data=None,
)
# Remove extra dimension added by ANTs
img = image.math_img("np.squeeze(img)", img=raw_img)
# Set correct affine as resolution won't be correct
img = image.resample_img(
img=img,
target_affine=target_img.affine,
img = nimg.math_img("np.squeeze(img)", img=raw_img)

if target_space != "native":
# No warping is going to happen, just resampling, because
# we are in the correct space
logger.debug(f"Resampling {name} to target image.")
# Resample parcellation to target image
img = nimg.resample_to_img(
source_img=img,
target_img=target_img,
interpolation="nearest",
copy=True,
)
else:
if target_space != "native":
# No warping is going to happen, just resampling, because
# we are in the correct space
logger.debug(f"Resampling {name} to target image.")
# Resample parcellation to target image
img = image.resample_to_img(
source_img=img,
target_img=target_img,
interpolation="nearest",
copy=True,
)

# Warp parcellation if target space is native
if target_space == "native":
# Warp parcellation if target space is native as either
# the image is in the right non-native space or it's
# warped from one non-native space to another non-native space
logger.debug(
"Warping parcellation to native space using "
f"{warper_spec['warper']}."
Expand Down Expand Up @@ -1807,7 +1803,7 @@ def merge_parcellations(
"The parcellations have different resolutions!"
"Resampling all parcellations to the first one in the list."
)
t_parc = image.resample_to_img(
t_parc = nimg.resample_to_img(
t_parc, ref_parc, interpolation="nearest", copy=True
)

Expand All @@ -1833,6 +1829,6 @@ def merge_parcellations(
"parcellation that was first in the list."
)

parcellation_img_res = image.new_img_like(parcellations_list[0], parc_data)
parcellation_img_res = nimg.new_img_like(parcellations_list[0], parc_data)

return parcellation_img_res, labels
4 changes: 0 additions & 4 deletions junifer/markers/reho/tests/test_reho_parcels.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,6 @@ def test_ReHoParcels(caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None:
@pytest.mark.skipif(
_check_afni() is False, reason="requires AFNI to be in PATH"
)
@pytest.mark.xfail(
reason="junifer ReHo needs to use the correct mask",
raises=AssertionError,
)
def test_ReHoParcels_comparison(tmp_path: Path) -> None:
"""Test ReHoParcels implementation comparison.
Expand Down

0 comments on commit 69275b2

Please sign in to comment.