Skip to content

Commit

Permalink
Merge pull request #284 from juaml/fix/inherit-mask-single-warp
Browse files Browse the repository at this point in the history
[BUG]: Mask "inherit" will be warped twice if working in native space.
  • Loading branch information
synchon authored Jan 15, 2024
2 parents 4d9e509 + 8e27ca2 commit 4f42f76
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/284.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid warping mask preprocessed with :class:`.fMRIPrepConfoundRemover` and used by markers with ``mask="inherit"`` in subject-native template space by `Fede Raimondo`_ and `Synchon Mandal`_
21 changes: 14 additions & 7 deletions junifer/data/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,16 @@ def get_mask( # noqa: C901
f"because the item ({inherited_mask_item}) does not exist."
)
mask_img = extra_input[inherited_mask_item]["data"]
mask_space = target_data["space"]
# Starting with new mask
else:
# Load mask
mask_object, _, mask_space = load_mask(
mask_name, path_only=False, resolution=resolution
)
# Replace mask space with target space if mask's space is inherit
if mask_space == "inherit":
mask_space = target_data["space"]
# If mask is callable like from nilearn
if callable(mask_object):
if mask_params is None:
Expand All @@ -306,20 +310,22 @@ def get_mask( # noqa: C901
interpolation="nearest",
copy=True,
)
all_spaces.append(mask_space)
all_spaces.append(mask_space)
all_masks.append(mask_img)

# Multiple masks, need intersection / union
if len(all_masks) > 1:
# Filter out "inherit" and make a set for spaces
filtered_spaces = set(filter(lambda x: x != "inherit", all_spaces))
# Make a set of unique spaces
unique_spaces = set(all_spaces)
# Intersect / union of masks only if all masks are in the same space
if len(filtered_spaces) == 1:
if len(unique_spaces) == 1:
mask_img = intersect_masks(all_masks, **intersect_params)
# Store the mask space for further checks
mask_space = next(iter(unique_spaces))
else:
raise_error(
msg=(
f"Masks are in different spaces: {filtered_spaces}, "
f"Masks are in different spaces: {unique_spaces}, "
"unable to merge."
),
klass=RuntimeError,
Expand All @@ -333,9 +339,10 @@ def get_mask( # noqa: C901
"when there is only one mask."
)
mask_img = all_masks[0]
mask_space = all_spaces[0]

# Warp mask if target data is native
if target_data["space"] == "native":
# Warp mask if target data is native and mask space is not native
if target_data["space"] == "native" and target_data["space"] != mask_space:
# Check for extra inputs
if extra_input is None:
raise_error(
Expand Down
12 changes: 6 additions & 6 deletions junifer/data/tests/test_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,9 @@ def test_get_mask_inherit() -> None:

# Now get the mask using the inherit functionality, passing the
# computed mask as extra data
extra_input = {"BOLD_MASK": {"data": gm_mask}}
extra_input = {
"BOLD_MASK": {"data": gm_mask, "space": input["BOLD"]["space"]}
}
input["BOLD"]["mask_item"] = "BOLD_MASK"
mask2 = get_mask(
masks="inherit", target_data=input["BOLD"], extra_input=extra_input
Expand All @@ -405,11 +407,9 @@ def test_get_mask_inherit() -> None:
@pytest.mark.parametrize(
"masks,params",
[
(["GM_prob0.2", "compute_brain_mask"], {}),
(
["GM_prob0.2", "compute_brain_mask"],
{"threshold": 0.2},
),
(["GM_prob0.2", "GM_prob0.2_cortex"], {}),
(["compute_brain_mask", "compute_background_mask"], {}),
(["compute_brain_mask", "compute_epi_mask"], {}),
],
)
def test_get_mask_multiple(
Expand Down
5 changes: 4 additions & 1 deletion junifer/preprocess/confounds/fmriprep_confound_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,10 @@ def _remove_confounds(
# this allows to use "inherit" down the pipeline
if extra_input is not None:
logger.debug("Setting mask_item")
extra_input["BOLD_mask"] = {"data": mask_img}
extra_input["BOLD_mask"] = {
"data": mask_img,
"space": input["space"],
}
input["mask_item"] = "BOLD_mask"

logger.info("Cleaning image")
Expand Down

0 comments on commit 4f42f76

Please sign in to comment.