Skip to content

Commit

Permalink
Merge pull request #398 from juaml/fix/multiple-dg-meta-update
Browse files Browse the repository at this point in the history
[BUG]: Fix metadata update for `MultipleDataGrabber`
  • Loading branch information
synchon authored Nov 21, 2024
2 parents 3a05a93 + 92bbd72 commit 1944462
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 62 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/398.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix metadata update for :class:`.MultipleDataGrabber` and adjust :meth:`.PatternDataGrabber.get_elements` to check ``list``-like data type values by `Fede Raimondo`_ and `Synchon Mandal`_
8 changes: 6 additions & 2 deletions junifer/datagrabber/multiple.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,12 @@ def __getitem__(self, element: Union[str, tuple]) -> dict:

# Update all the metas again
for kind in out:
self.update_meta(out[kind], "datagrabber")
out[kind]["meta"]["datagrabber"]["datagrabbers"] = metas
to_update = out[kind]
if not isinstance(to_update, list):
to_update = [to_update]
for t_kind in to_update:
self.update_meta(t_kind, "datagrabber")
t_kind["meta"]["datagrabber"]["datagrabbers"] = metas
return out

def __enter__(self) -> "MultipleDataGrabber":
Expand Down
110 changes: 58 additions & 52 deletions junifer/datagrabber/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numpy as np

from ..api.decorators import register_datagrabber
from ..typing import DataGrabberPatterns
from ..utils import logger, raise_error
from .base import BaseDataGrabber
from .pattern_validation_mixin import PatternValidationMixin
Expand Down Expand Up @@ -171,7 +172,7 @@ class PatternDataGrabber(BaseDataGrabber, PatternValidationMixin):
def __init__(
self,
types: list[str],
patterns: dict[str, dict[str, str]],
patterns: DataGrabberPatterns,
replacements: Union[list[str], str],
datadir: Union[str, Path],
confounds_format: Optional[str] = None,
Expand Down Expand Up @@ -478,58 +479,63 @@ def get_elements(self) -> list:
t_type = self.types[t_idx]
types_element = set()

# Get the pattern dict
t_pattern = self.patterns[t_type]
# Conditional fetch of base pattern for getting elements
pattern = None
# Try for data type pattern
pattern = t_pattern.get("pattern")
# Try for nested data type pattern
if pattern is None and self.partial_pattern_ok:
for v in t_pattern.values():
if isinstance(v, dict) and "pattern" in v:
pattern = v["pattern"]
break

# Replace the pattern
(
re_pattern,
glob_pattern,
t_replacements,
) = self._replace_patterns_regex(pattern)
for fname in self.datadir.glob(glob_pattern):
suffix = fname.relative_to(self.datadir).as_posix()
m = re.match(re_pattern, suffix)
if m is not None:
# Find the groups of replacements present in the pattern
# If one replacement is not present, set it to None.
# We will take care of this in the intersection
t_element = tuple([m.group(k) for k in t_replacements])
if len(self.replacements) == 1:
t_element = t_element[0]
types_element.add(t_element)
# TODO: does this make sense as elements is always None
if elements is None:
elements = types_element
else:
# Do the intersection by filtering out elements in which
# the replacements are not None
if t_replacements == self.replacements:
elements.intersection(types_element)
# Data type dictionary
patterns = self.patterns[t_type]
# Conditional for list dtype vals like Warp
if not isinstance(patterns, list):
patterns = [patterns]
for t_pattern in patterns:
# Conditional fetch of base pattern for getting elements
pattern = None
# Try for data type pattern
pattern = t_pattern.get("pattern")
# Try for nested data type pattern
if pattern is None and self.partial_pattern_ok:
for v in t_pattern.values():
if isinstance(v, dict) and "pattern" in v:
pattern = v["pattern"]
break

# Replace the pattern
(
re_pattern,
glob_pattern,
t_replacements,
) = self._replace_patterns_regex(pattern)
for fname in self.datadir.glob(glob_pattern):
suffix = fname.relative_to(self.datadir).as_posix()
m = re.match(re_pattern, suffix)
if m is not None:
# Find the groups of replacements present in the
# pattern. If one replacement is not present, set it
# to None. We will take care of this in the
# intersection.
t_element = tuple([m.group(k) for k in t_replacements])
if len(self.replacements) == 1:
t_element = t_element[0]
types_element.add(t_element)
# TODO: does this make sense as elements is always None
if elements is None:
elements = types_element
else:
t_repl_idx = [
i
for i, v in enumerate(self.replacements)
if v in t_replacements
]
new_elements = set()
for t_element in elements:
if (
tuple(np.array(t_element)[t_repl_idx])
in types_element
):
new_elements.add(t_element)
elements = new_elements
# Do the intersection by filtering out elements in which
# the replacements are not None
if t_replacements == self.replacements:
elements.intersection(types_element)
else:
t_repl_idx = [
i
for i, v in enumerate(self.replacements)
if v in t_replacements
]
new_elements = set()
for t_element in elements:
if (
tuple(np.array(t_element)[t_repl_idx])
in types_element
):
new_elements.add(t_element)
elements = new_elements
if elements is None:
elements = set()
return list(elements)
6 changes: 3 additions & 3 deletions junifer/datagrabber/pattern_validation_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# Authors: Synchon Mandal <s.mandal@fz-juelich.de>
# License: AGPL

from typing import Union

from ..typing import DataGrabberPatterns
from ..utils import logger, raise_error, warn_with_log


Expand Down Expand Up @@ -96,7 +96,7 @@ def _validate_types(self, types: list[str]) -> None:
def _validate_replacements(
self,
replacements: list[str],
patterns: dict[str, Union[dict[str, str], list[dict[str, str]]]],
patterns: DataGrabberPatterns,
partial_pattern_ok: bool,
) -> None:
"""Validate the replacements.
Expand Down Expand Up @@ -263,7 +263,7 @@ def validate_patterns(
self,
types: list[str],
replacements: list[str],
patterns: dict[str, Union[dict[str, str], list[dict[str, str]]]],
patterns: DataGrabberPatterns,
partial_pattern_ok: bool = False,
) -> None:
"""Validate the patterns.
Expand Down
4 changes: 2 additions & 2 deletions junifer/datagrabber/tests/test_datalad_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
_testing_dataset = {
"example_bids": {
"uri": "https://gin.g-node.org/juaml/datalad-example-bids",
"commit": "b87897cbe51bf0ee5514becaa5c7dd76491db5ad",
"id": "8fddff30-6993-420a-9d1e-b5b028c59468",
"commit": "3f288c8725207ae0c9b3616e093e78cda192b570",
"id": "582b9696-f13f-42e4-9587-b4e62aa2a8e7",
},
"example_bids_ses": {
"uri": "https://gin.g-node.org/juaml/datalad-example-bids-ses",
Expand Down
50 changes: 48 additions & 2 deletions junifer/datagrabber/tests/test_multiple.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_MultipleDataGrabber() -> None:
dg1 = PatternDataladDataGrabber(
rootdir=rootdir,
uri=repo_uri,
types=["T1w"],
types=["T1w", "Warp"],
patterns={
"T1w": {
"pattern": (
Expand All @@ -44,6 +44,28 @@ def test_MultipleDataGrabber() -> None:
"space": "native",
},
},
"Warp": [
{
"pattern": (
"{subject}/{session}/anat/"
"{subject}_{session}_from-MNI152NLin2009cAsym_to-T1w_"
"xfm.h5"
),
"src": "MNI152NLin2009cAsym",
"dst": "native",
"warper": "ants",
},
{
"pattern": (
"{subject}/{session}/anat/"
"{subject}_{session}_from-T1w_to-MNI152NLin2009cAsym_"
"xfm.h5"
),
"src": "native",
"dst": "MNI152NLin2009cAsym",
"warper": "ants",
},
],
},
replacements=replacements,
)
Expand Down Expand Up @@ -75,6 +97,7 @@ def test_MultipleDataGrabber() -> None:

types = dg.get_types()
assert "T1w" in types
assert "Warp" in types
assert "BOLD" in types

expected_subs = [
Expand All @@ -90,6 +113,7 @@ def test_MultipleDataGrabber() -> None:
elem = dg[("sub-01", "ses-01")]
# Check data types
assert "T1w" in elem
assert "Warp" in elem
assert "BOLD" in elem
# Check meta
assert "meta" in elem["BOLD"]
Expand All @@ -111,14 +135,36 @@ def test_MultipleDataGrabber_no_intersection() -> None:
dg1 = PatternDataladDataGrabber(
rootdir=rootdir,
uri=_testing_dataset["example_bids"]["uri"],
types=["T1w"],
types=["T1w", "Warp"],
patterns={
"T1w": {
"pattern": (
"{subject}/{session}/anat/{subject}_{session}_T1w.nii.gz"
),
"space": "native",
},
"Warp": [
{
"pattern": (
"{subject}/{session}/anat/"
"{subject}_{session}_from-MNI152NLin2009cAsym_to-T1w_"
"xfm.h5"
),
"src": "MNI152NLin2009cAsym",
"dst": "native",
"warper": "ants",
},
{
"pattern": (
"{subject}/{session}/anat/"
"{subject}_{session}_from-T1w_to-MNI152NLin2009cAsym_"
"xfm.h5"
),
"src": "native",
"dst": "MNI152NLin2009cAsym",
"warper": "ants",
},
],
},
replacements=replacements,
)
Expand Down
2 changes: 1 addition & 1 deletion junifer/datagrabber/tests/test_pattern_datalad.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
_testing_dataset = {
"example_bids": {
"uri": "https://gin.g-node.org/juaml/datalad-example-bids",
"commit": "b87897cbe51bf0ee5514becaa5c7dd76491db5ad",
"commit": "3f288c8725207ae0c9b3616e093e78cda192b570",
"id": "8fddff30-6993-420a-9d1e-b5b028c59468",
},
"example_bids_ses": {
Expand Down
2 changes: 2 additions & 0 deletions junifer/typing/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ __all__ = [
"ConditionalDependencies",
"ExternalDependencies",
"MarkerInOutMappings",
"DataGrabberPatterns",
]

from ._typing import (
Expand All @@ -20,4 +21,5 @@ from ._typing import (
ConditionalDependencies,
ExternalDependencies,
MarkerInOutMappings,
DataGrabberPatterns,
)
4 changes: 4 additions & 0 deletions junifer/typing/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"ConditionalDependencies",
"ExternalDependencies",
"MarkerInOutMappings",
"DataGrabberPatterns",
]


Expand Down Expand Up @@ -56,3 +57,6 @@
]
ExternalDependencies = Sequence[MutableMapping[str, Union[str, Sequence[str]]]]
MarkerInOutMappings = MutableMapping[str, MutableMapping[str, str]]
# Alias for the ``patterns`` argument of pattern-based data grabbers: maps a
# data type name to either a single pattern dict or a sequence of pattern
# dicts (list-valued data types such as ``Warp``).
DataGrabberPatterns = dict[str, Union[dict[str, str], Sequence[dict[str, str]]]]
2 changes: 2 additions & 0 deletions tools/create_bids_example_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
fnames = [
f"anat/{t_sub}_T1w.nii.gz",
f"anat/{t_sub}_brain_mask.nii.gz",
f"anat/{t_sub}_from-MNI152NLin2009cAsym_to-T1w_xfm.h5",
f"anat/{t_sub}_from-T1w_to-MNI152NLin2009cAsym_xfm.h5",
f"func/{t_sub}_task-rest_bold.nii.gz",
f"func/{t_sub}_task-rest_bold.json",
f"func/{t_sub}_task-rest_brain_mask.nii.gz",
Expand Down
2 changes: 2 additions & 0 deletions tools/create_bids_example_dataset_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
# Anatomical file names for one subject/session of the BIDS example dataset
# (``t_sub``/``t_ses`` are bound by an enclosing loop outside this view —
# TODO confirm against the full script).
# The two ``*_xfm.h5`` entries (added by this commit) are the
# MNI152NLin2009cAsym <-> T1w transform files, so ``Warp`` patterns resolve.
fnames = [
f"anat/{t_sub}_{t_ses}_T1w.nii.gz",
f"anat/{t_sub}_{t_ses}_brain_mask.nii.gz",
f"anat/{t_sub}_{t_ses}_from-MNI152NLin2009cAsym_to-T1w_xfm.h5",
f"anat/{t_sub}_{t_ses}_from-T1w_to-MNI152NLin2009cAsym_xfm.h5",
]
if i_ses != 3: # Session 3 does not have functional data
fnames.extend(
Expand Down

0 comments on commit 1944462

Please sign in to comment.