Skip to content

Commit

Permalink
Merge pull request #411 from juaml/refactor/parcellation-logic
Browse files Browse the repository at this point in the history
[ENH]: Rework parcellation logic
  • Loading branch information
synchon authored Dec 6, 2024
2 parents ed5db99 + b86d632 commit 12ecdb2
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 53 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/411.change
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:meth:`.ParcellationRegistry.load` now has a new parameter ``target_space`` to provide parcellation in highest possible resolution if it does not match the space of the parcellation by `Fede Raimondo`_ and `Synchon Mandal`_
1 change: 1 addition & 0 deletions docs/changes/newsfragments/411.enh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor the parcellation logic in the pipeline to account for and optimise space transformations and merges by `Synchon Mandal`_ and `Fede Raimondo`_
89 changes: 55 additions & 34 deletions junifer/data/parcellations/_parcellations.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def deregister(self, name: str) -> None:
def load(
self,
name: str,
target_space: str,
parcellations_dir: Union[str, Path, None] = None,
resolution: Optional[float] = None,
path_only: bool = False,
Expand All @@ -282,6 +283,8 @@ def load(
----------
name : str
The name of the parcellation.
target_space : str
The desired space of the parcellation.
parcellations_dir : str or pathlib.Path, optional
Path where the parcellations files are stored. The default location
is "$HOME/junifer/data/parcellations" (default None).
Expand Down Expand Up @@ -328,6 +331,14 @@ def load(
else:
space = parcellation_definition["space"]

# Check and get highest resolution
if space != target_space:
logger.info(
f"Parcellation will be warped from {space} to {target_space} "
"using highest resolution"
)
resolution = None

# Check if the parcellation family is custom or built-in
if t_family == "CustomUserParcellation":
parcellation_fname = Path(parcellation_definition["path"])
Expand Down Expand Up @@ -441,13 +452,14 @@ def get(
img, labels, _, space = self.load(
name=name,
resolution=resolution,
target_space=target_space,
)

# Convert parcellation spaces if required;
# cannot be "native" due to earlier check
if space != target_std_space:
logger.debug(
f"Warping {name} to {target_std_space} space using ants."
f"Warping {name} to {target_std_space} space using ANTs."
)
raw_img = ANTsParcellationWarper().warp(
parcellation_name=name,
Expand All @@ -459,17 +471,50 @@ def get(
)
# Remove extra dimension added by ANTs
img = image.math_img("np.squeeze(img)", img=raw_img)
# Set correct affine as resolution won't be correct
img = image.resample_img(
img=img,
target_affine=target_img.affine,
interpolation="nearest",
)
else:
if target_space != "native":
# No warping is going to happen, just resampling, because
# we are in the correct space
logger.debug(f"Resampling {name} to target image.")
# Resample parcellation to target image
img = image.resample_to_img(
source_img=img,
target_img=target_img,
interpolation="nearest",
copy=True,
)

logger.debug(f"Resampling {name} to target image.")
# Resample parcellation to target image
img_to_merge = image.resample_to_img(
source_img=img,
target_img=target_img,
interpolation="nearest",
copy=True,
)
# Warp parcellation if target space is native
if target_space == "native":
logger.debug(
"Warping parcellation to native space using "
f"{warper_spec['warper']}."
)
# extra_input check done earlier and warper_spec exists
if warper_spec["warper"] == "fsl":
img = FSLParcellationWarper().warp(
parcellation_name="native",
parcellation_img=img,
target_data=target_data,
warp_data=warper_spec,
)
elif warper_spec["warper"] == "ants":
img = ANTsParcellationWarper().warp(
parcellation_name="native",
parcellation_img=img,
src="",
dst="native",
target_data=target_data,
warp_data=warper_spec,
)

all_parcellations.append(img_to_merge)
all_parcellations.append(img)
all_labels.append(labels)

# Avoid merging if there is only one parcellation
Expand All @@ -485,30 +530,6 @@ def get(
labels_lists=all_labels,
)

# Warp parcellation if target space is native
if target_space == "native":
logger.debug(
"Warping parcellation to native space using "
f"{warper_spec['warper']}."
)
# extra_input check done earlier and warper_spec exists
if warper_spec["warper"] == "fsl":
resampled_parcellation_img = FSLParcellationWarper().warp(
parcellation_name="native",
parcellation_img=resampled_parcellation_img,
target_data=target_data,
warp_data=warper_spec,
)
elif warper_spec["warper"] == "ants":
resampled_parcellation_img = ANTsParcellationWarper().warp(
parcellation_name="native",
parcellation_img=resampled_parcellation_img,
src="",
dst="native",
target_data=target_data,
warp_data=warper_spec,
)

return resampled_parcellation_img, labels


Expand Down
64 changes: 45 additions & 19 deletions junifer/data/parcellations/tests/test_parcellations.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def test_register_already_registered() -> None:
space="MNI152Lin",
)
assert (
ParcellationRegistry().load("testparc", path_only=True)[2].name
ParcellationRegistry()
.load("testparc", target_space="MNI152Lin", path_only=True)[2]
.name
== "testparc.nii.gz"
)

Expand All @@ -81,7 +83,9 @@ def test_register_already_registered() -> None:
)

assert (
ParcellationRegistry().load("testparc", path_only=True)[2].name
ParcellationRegistry()
.load("testparc", target_space="MNI152Lin", path_only=True)[2]
.name
== "testparc2.nii.gz"
)

Expand All @@ -96,7 +100,8 @@ def test_parcellation_wrong_labels_values(tmp_path: Path) -> None:
"""
schaefer, labels, schaefer_path, _ = ParcellationRegistry().load(
"Schaefer100x7"
"Schaefer100x7",
"MNI152NLin6Asym",
)
assert schaefer is not None

Expand All @@ -106,15 +111,15 @@ def test_parcellation_wrong_labels_values(tmp_path: Path) -> None:
)

with pytest.raises(ValueError, match=r"has 100 parcels but 10"):
ParcellationRegistry().load("WrongLabels")
ParcellationRegistry().load("WrongLabels", "MNI152NLin6Asym")

# Test wrong number of labels
ParcellationRegistry().register(
"WrongLabels2", schaefer_path, [*labels, "wrong"], "MNI152Lin"
)

with pytest.raises(ValueError, match=r"has 100 parcels but 101"):
ParcellationRegistry().load("WrongLabels2")
ParcellationRegistry().load("WrongLabels2", "MNI152NLin6Asym")

schaefer_data = schaefer.get_fdata().copy()
schaefer_data[schaefer_data == 50] = 0
Expand All @@ -126,7 +131,7 @@ def test_parcellation_wrong_labels_values(tmp_path: Path) -> None:
"WrongValues", new_schaefer_path, labels[:-1], "MNI152Lin"
)
with pytest.raises(ValueError, match=r"must have all the values in the"):
ParcellationRegistry().load("WrongValues")
ParcellationRegistry().load("WrongValues", "MNI152NLin6Asym")

schaefer_data = schaefer.get_fdata().copy()
schaefer_data[schaefer_data == 50] = 200
Expand All @@ -138,7 +143,7 @@ def test_parcellation_wrong_labels_values(tmp_path: Path) -> None:
"WrongValues2", new_schaefer_path, labels, "MNI152Lin"
)
with pytest.raises(ValueError, match=r"must have all the values in the"):
ParcellationRegistry().load("WrongValues2")
ParcellationRegistry().load("WrongValues2", "MNI152NLin6Asym")


@pytest.mark.parametrize(
Expand Down Expand Up @@ -202,7 +207,7 @@ def test_register(
assert name in ParcellationRegistry().list
# Load registered parcellation
_, lbl, fname, parcellation_space = ParcellationRegistry().load(
name=name, path_only=True
name=name, target_space=space, path_only=True
)
# Check values for registered parcellation
assert lbl == parcels_labels
Expand Down Expand Up @@ -239,7 +244,7 @@ def test_list_correct(parcellation_name: str) -> None:
def test_load_incorrect() -> None:
"""Test loading of invalid parcellations."""
with pytest.raises(ValueError, match=r"not found"):
ParcellationRegistry().load("wrongparcellation")
ParcellationRegistry().load("wrongparcellation", "MNI152NLin6Asym")


def test_retrieve_parcellation_incorrect() -> None:
Expand Down Expand Up @@ -323,6 +328,7 @@ def test_schaefer(
# Load parcellation
img, label, img_path, space = ParcellationRegistry().load(
name=parcellation_name,
target_space="MNI152NLin6Asym",
parcellations_dir=tmp_path,
resolution=resolution,
)
Expand Down Expand Up @@ -392,6 +398,7 @@ def test_suit(tmp_path: Path, space_key: str, space: str) -> None:
# Load parcellation
img, label, img_path, parcellation_space = ParcellationRegistry().load(
name=f"SUITx{space_key}",
target_space=space,
parcellations_dir=tmp_path,
)
assert img is not None
Expand Down Expand Up @@ -441,7 +448,9 @@ def test_tian_3T_6thgeneration(
assert "TianxS4x3TxMNI6thgeneration" in parcellations
# Load parcellation
img, lbl, fname, parcellation_space_1 = ParcellationRegistry().load(
name=f"TianxS{scale}x3TxMNI6thgeneration", parcellations_dir=tmp_path
name=f"TianxS{scale}x3TxMNI6thgeneration",
parcellations_dir=tmp_path,
target_space="MNI152NLin2009cAsym",
)
fname1 = f"Tian_Subcortex_S{scale}_3T_1mm.nii.gz"
assert img is not None
Expand All @@ -452,6 +461,7 @@ def test_tian_3T_6thgeneration(
# Load parcellation
img, lbl, fname, parcellation_space_2 = ParcellationRegistry().load(
name=f"TianxS{scale}x3TxMNI6thgeneration",
target_space="MNI152NLin6Asym",
parcellations_dir=tmp_path,
resolution=2,
)
Expand Down Expand Up @@ -489,6 +499,7 @@ def test_tian_3T_nonlinear2009cAsym(
# Load parcellation
img, lbl, fname, space = ParcellationRegistry().load(
name=f"TianxS{scale}x3TxMNInonlinear2009cAsym",
target_space="MNI152NLin2009cAsym",
parcellations_dir=tmp_path,
)
fname1 = f"Tian_Subcortex_S{scale}_3T_2009cAsym.nii.gz"
Expand Down Expand Up @@ -524,7 +535,9 @@ def test_tian_7T_6thgeneration(
assert "TianxS4x7TxMNI6thgeneration" in parcellations
# Load parcellation
img, lbl, fname, space = ParcellationRegistry().load(
name=f"TianxS{scale}x7TxMNI6thgeneration", parcellations_dir=tmp_path
name=f"TianxS{scale}x7TxMNI6thgeneration",
target_space="MNI152NLin6Asym",
parcellations_dir=tmp_path,
)
fname1 = f"Tian_Subcortex_S{scale}_7T.nii.gz"
assert img is not None
Expand Down Expand Up @@ -611,7 +624,9 @@ def test_aicha(tmp_path: Path, version: int) -> None:
assert f"AICHA_v{version}" in ParcellationRegistry().list
# Load parcellation
img, label, img_path, space = ParcellationRegistry().load(
name=f"AICHA_v{version}", parcellations_dir=tmp_path
name=f"AICHA_v{version}",
target_space="IXI549Space",
parcellations_dir=tmp_path,
)
assert img is not None
assert img_path.name == "AICHA.nii"
Expand Down Expand Up @@ -680,6 +695,7 @@ def test_shen(
# Load parcellation
img, label, img_path, space = ParcellationRegistry().load(
name=f"Shen_{year}_{n_rois}",
target_space="MNI152NLin2009cAsym",
parcellations_dir=tmp_path,
resolution=resolution,
)
Expand Down Expand Up @@ -875,7 +891,8 @@ def test_yan(
)
# Load parcellation
img, label, img_path, space = ParcellationRegistry().load(
name=parcellation_name, # type: ignore
name=parcellation_name,
target_space="MNI152NLin6Asym",
parcellations_dir=tmp_path,
resolution=resolution,
)
Expand Down Expand Up @@ -1012,6 +1029,7 @@ def test_brainnetome(
# Load parcellation
img, label, img_path, space = ParcellationRegistry().load(
name=parcellation_name,
target_space="MNI152NLin6Asym",
parcellations_dir=tmp_path,
resolution=resolution,
)
Expand Down Expand Up @@ -1044,10 +1062,11 @@ def test_merge_parcellations() -> None:
"""Test merging parcellations."""
# load some parcellations for testing
schaefer_parcellation, schaefer_labels, _, _ = ParcellationRegistry().load(
"Schaefer100x17"
"Schaefer100x17", target_space="MNI152NLin2009cAsym"
)
tian_parcellation, tian_labels, _, _ = ParcellationRegistry().load(
"TianxS2x3TxMNInonlinear2009cAsym"
"TianxS2x3TxMNInonlinear2009cAsym",
target_space="MNI152NLin2009cAsym",
)
# prepare the list of the actual parcellations
parcellation_list = [schaefer_parcellation, tian_parcellation]
Expand Down Expand Up @@ -1079,7 +1098,9 @@ def test_merge_parcellations_3D_multiple_non_overlapping(
"""
# Get the testing parcellation
parcellation, labels, _, _ = ParcellationRegistry().load("Schaefer100x7")
parcellation, labels, _, _ = ParcellationRegistry().load(
"Schaefer100x7", target_space="MNI152NLin2009cAsym"
)

assert parcellation is not None

Expand Down Expand Up @@ -1114,7 +1135,9 @@ def test_merge_parcellations_3D_multiple_overlapping() -> None:
"""Test merge_parcellations with multiple overlapping parcellations."""

# Get the testing parcellation
parcellation, labels, _, _ = ParcellationRegistry().load("Schaefer100x7")
parcellation, labels, _, _ = ParcellationRegistry().load(
"Schaefer100x7", target_space="MNI152NLin2009cAsym"
)

assert parcellation is not None

Expand Down Expand Up @@ -1149,7 +1172,9 @@ def test_merge_parcellations_3D_multiple_duplicated_labels() -> None:
"""Test merge_parcellations with duplicated labels."""

# Get the testing parcellation
parcellation, labels, _, _ = ParcellationRegistry().load("Schaefer100x7")
parcellation, labels, _, _ = ParcellationRegistry().load(
"Schaefer100x7", target_space="MNI152NLin2009cAsym"
)

assert parcellation is not None

Expand Down Expand Up @@ -1198,6 +1223,7 @@ def test_get_single() -> None:
# Get raw parcellation
raw_parcellation, raw_labels, _, _ = ParcellationRegistry().load(
"TianxS1x3TxMNInonlinear2009cAsym",
target_space="MNI152NLin2009cAsym",
resolution=1.5,
)
resampled_raw_parcellation = resample_to_img(
Expand Down Expand Up @@ -1240,7 +1266,7 @@ def test_get_multi_same_space() -> None:
]
for name in parcellations_names:
img, labels, _, _ = ParcellationRegistry().load(
name=name, resolution=1.5
name=name, target_space="MNI152NLin2009cAsym", resolution=1.5
)
# Resample raw parcellations
resampled_img = resample_to_img(
Expand Down
4 changes: 4 additions & 0 deletions junifer/markers/reho/tests/test_reho_parcels.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def test_ReHoParcels(caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None:
@pytest.mark.skipif(
_check_afni() is False, reason="requires AFNI to be in PATH"
)
@pytest.mark.xfail(
reason="junifer ReHo needs to use the correct mask",
raises=AssertionError,
)
def test_ReHoParcels_comparison(tmp_path: Path) -> None:
"""Test ReHoParcels implementation comparison.
Expand Down

0 comments on commit 12ecdb2

Please sign in to comment.