
Merge pull request #319 from juaml/enh/raise-invalid-element-selector
[ENH]: Raise/warn if element is not run
synchon authored Jan 8, 2025
2 parents 7e12e23 + fd25474 commit fd8cfc2
Showing 11 changed files with 186 additions and 65 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/319.enh
@@ -0,0 +1 @@
Raise error when partial or complete element selectors are invalid when running the pipeline by `Synchon Mandal`_
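
In practice this means that a run whose element selectors cannot be resolved by the DataGrabber now fails loudly instead of silently processing nothing. A schematic sketch of the new failure mode, assuming the testing DataGrabber used in this PR's tests; the marker and storage configurations are illustrative placeholders (marker name, parcellation choice, and paths are hypothetical) since this PR does not touch them:

from junifer.api import run

# Only the ``elements`` handling below is what this PR changes; the marker
# and storage dicts stand in for whatever your pipeline uses.
run(
    workdir="/tmp/junifer_work",
    datagrabber={"kind": "PartlyCloudyTestingDataGrabber"},
    markers=[
        {
            "name": "example-mean",       # hypothetical marker name
            "kind": "ParcelAggregation",  # any valid junifer marker kind
            "parcellation": "Schaefer100x7",  # hypothetical choice
            "method": "mean",
        }
    ],
    storage={"kind": "SQLiteFeatureStorage", "uri": "/tmp/out.sqlite"},
    elements=["sub-999"],  # selector not present in the dataset
)
# -> RuntimeError, roughly: "The following element selectors are invalid: {'sub-999'}"

As the updated tests below show, the selector format matters too: for the single-replacement testing DataGrabber, ["sub-01"] runs fine while [("sub-01",)] is now reported as invalid.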
42 changes: 31 additions & 11 deletions junifer/api/functions.py
@@ -20,7 +20,13 @@
)
from ..preprocess import BasePreprocessor
from ..storage import BaseFeatureStorage
from ..typing import DataGrabberLike, MarkerLike, PreprocessorLike, StorageLike
from ..typing import (
DataGrabberLike,
Elements,
MarkerLike,
PreprocessorLike,
StorageLike,
)
from ..utils import logger, raise_error, warn_with_log, yaml


@@ -121,7 +127,7 @@ def run(
markers: list[dict],
storage: dict,
preprocessors: Optional[list[dict]] = None,
elements: Optional[list[tuple[str, ...]]] = None,
elements: Optional[Elements] = None,
) -> None:
"""Run the pipeline on the selected element.
@@ -147,14 +153,16 @@ def run(
List of preprocessors to use. Each preprocessor is a dict with at
least a key ``kind`` specifying the preprocessor to use. All other keys
are passed to the preprocessor constructor (default None).
elements : list of tuple or None, optional
elements : list or None, optional
Element(s) to process. Will be used to index the DataGrabber
(default None).
Raises
------
ValueError
If ``workdir.cleanup=False`` when ``len(elements) > 1``.
RuntimeError
If invalid element selectors are found.
"""
# Conditional to handle workdir config
@@ -208,10 +216,22 @@ def run(
# Fit elements
with datagrabber_object:
if elements is not None:
for t_element in datagrabber_object.filter(
elements # type: ignore
):
# Keep track of valid selectors
valid_elements = []
for t_element in datagrabber_object.filter(elements):
valid_elements.append(t_element)
mc.fit(datagrabber_object[t_element])
# Compute invalid selectors
invalid_elements = set(elements) - set(valid_elements)
# Report if invalid selectors are found
if invalid_elements:
raise_error(
msg=(
"The following element selectors are invalid:\n"
f"{invalid_elements}"
),
klass=RuntimeError,
)
else:
for t_element in datagrabber_object:
mc.fit(datagrabber_object[t_element])
@@ -243,7 +263,7 @@ def queue(
kind: str,
jobname: str = "junifer_job",
overwrite: bool = False,
elements: Optional[list[tuple[str, ...]]] = None,
elements: Optional[Elements] = None,
**kwargs: Union[str, int, bool, dict, tuple, list],
) -> None:
"""Queue a job to be executed later.
@@ -258,7 +278,7 @@
The name of the job (default "junifer_job").
overwrite : bool, optional
Whether to overwrite if job directory already exists (default False).
elements : list of tuple or None, optional
elements : list or None, optional
Element(s) to process. Will be used to index the DataGrabber
(default None).
**kwargs : dict
@@ -341,7 +361,7 @@ def queue(
elements = dg.get_elements()
# Listify elements
if not isinstance(elements, list):
elements: list[Union[str, tuple]] = [elements]
elements: Elements = [elements]

# Check job queueing system
adapter = None
@@ -406,7 +426,7 @@ def reset(config: dict) -> None:

def list_elements(
datagrabber: dict,
elements: Optional[list[tuple[str, ...]]] = None,
elements: Optional[Elements] = None,
) -> str:
"""List elements of the datagrabber filtered using `elements`.
@@ -416,7 +436,7 @@
DataGrabber to index. Must have a key ``kind`` with the kind of
DataGrabber to use. All other keys are passed to the DataGrabber
constructor.
elements : list of tuple or None, optional
elements : list or None, optional
Element(s) to filter using. Will be used to index the DataGrabber
(default None).
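
The essence of the new check in run above: every selector passed in must be yielded back by datagrabber_object.filter(elements), and whatever is left over is reported as invalid. A self-contained sketch of that set-difference pattern (not the junifer code itself, just the idea):

def report_invalid_selectors(requested: list, yielded: list) -> None:
    """Raise if any requested element selector was never yielded by the filter."""
    invalid = set(requested) - set(yielded)
    if invalid:
        raise RuntimeError(
            f"The following element selectors are invalid:\n{invalid}"
        )

# The filter yielded only ("sub-01", "ses-01"); ("sub-01", "ses-100") is invalid.
report_invalid_selectors(
    requested=[("sub-01", "ses-01"), ("sub-01", "ses-100")],
    yielded=[("sub-01", "ses-01")],
)  # -> RuntimeError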
5 changes: 3 additions & 2 deletions junifer/api/queue_context/gnu_parallel_local_adapter.py
@@ -6,8 +6,9 @@
import shutil
import textwrap
from pathlib import Path
from typing import Optional, Union
from typing import Optional

from ...typing import Elements
from ...utils import logger, make_executable, raise_error, run_ext_cmd
from .queue_context_adapter import QueueContextAdapter

@@ -63,7 +64,7 @@ def __init__(
job_name: str,
job_dir: Path,
yaml_config_path: Path,
elements: list[Union[str, tuple]],
elements: Elements,
pre_run: Optional[str] = None,
pre_collect: Optional[str] = None,
env: Optional[dict[str, str]] = None,
5 changes: 3 additions & 2 deletions junifer/api/queue_context/htcondor_adapter.py
@@ -6,8 +6,9 @@
import shutil
import textwrap
from pathlib import Path
from typing import Optional, Union
from typing import Optional

from ...typing import Elements
from ...utils import logger, make_executable, raise_error, run_ext_cmd
from .queue_context_adapter import QueueContextAdapter

@@ -81,7 +82,7 @@ def __init__(
job_name: str,
job_dir: Path,
yaml_config_path: Path,
elements: list[Union[str, tuple]],
elements: Elements,
pre_run: Optional[str] = None,
pre_collect: Optional[str] = None,
env: Optional[dict[str, str]] = None,
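
Both queue adapters (and functions.py above) now annotate elements with the shared junifer.typing.Elements alias instead of spelling out the union in every signature. The alias definition is not part of this diff; judging from the annotations it replaces, it presumably amounts to something like:

# Assumed shape of the alias (not shown in this diff): a list of element
# selectors, each either a plain string or a tuple of strings for
# DataGrabbers with more than one replacement.
from typing import Union

Elements = list[Union[str, tuple[str, ...]]]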
143 changes: 115 additions & 28 deletions junifer/api/tests/test_functions.py
@@ -6,16 +6,19 @@
# License: AGPL

import logging
from contextlib import AbstractContextManager, nullcontext
from pathlib import Path
from typing import Optional, Union
from typing import Any, Optional, Union

import pytest
from nibabel.filebasedimages import ImageFileError
from ruamel.yaml import YAML

import junifer.testing.registry # noqa: F401
from junifer.api import collect, list_elements, queue, reset, run
from junifer.datagrabber.base import BaseDataGrabber
from junifer.pipeline import PipelineComponentRegistry
from junifer.typing import Elements


# Configure YAML class
@@ -25,12 +28,37 @@
yaml.indent(mapping=2, sequence=4, offset=2)


# Kept for parametrizing
_datagrabber = {
"kind": "PartlyCloudyTestingDataGrabber",
}
_bids_ses_datagrabber = {
"kind": "PatternDataladDataGrabber",
"uri": "https://gin.g-node.org/juaml/datalad-example-bids-ses",
"types": ["T1w", "BOLD"],
"patterns": {
"T1w": {
"pattern": (
"{subject}/{session}/anat/{subject}_{session}_T1w.nii.gz"
),
"space": "MNI152NLin6Asym",
},
"BOLD": {
"pattern": (
"{subject}/{session}/func/{subject}_{session}_task-rest_bold.nii.gz"
),
"space": "MNI152NLin6Asym",
},
},
"replacements": ["subject", "session"],
"rootdir": "example_bids_ses",
}


@pytest.fixture
def datagrabber() -> dict[str, str]:
"""Return a datagrabber as a dictionary."""
return {
"kind": "PartlyCloudyTestingDataGrabber",
}
return _datagrabber.copy()


@pytest.fixture
@@ -60,11 +88,48 @@ def storage() -> dict[str, str]:
}


@pytest.mark.parametrize(
"datagrabber, element, expect",
[
(
_datagrabber,
[("sub-01",)],
pytest.raises(RuntimeError, match="element selectors are invalid"),
),
(
_datagrabber,
["sub-01"],
nullcontext(),
),
(
_bids_ses_datagrabber,
["sub-01"],
pytest.raises(ImageFileError, match="is not a gzip file"),
),
(
_bids_ses_datagrabber,
[("sub-01", "ses-01")],
pytest.raises(ImageFileError, match="is not a gzip file"),
),
(
_bids_ses_datagrabber,
[("sub-01", "ses-100")],
pytest.raises(RuntimeError, match="element selectors are invalid"),
),
(
_bids_ses_datagrabber,
[("sub-100", "ses-01")],
pytest.raises(RuntimeError, match="element selectors are invalid"),
),
],
)
def test_run_single_element(
tmp_path: Path,
datagrabber: dict[str, str],
datagrabber: dict[str, Any],
markers: list[dict[str, str]],
storage: dict[str, str],
element: Elements,
expect: AbstractContextManager,
) -> None:
"""Test run function with single element.
@@ -78,21 +143,26 @@ def test_run_single_element(
Testing markers as list of dictionary.
storage : dict
Testing storage as dictionary.
element : list of str or tuple
The parametrized element.
expect : typing.ContextManager
The parametrized ContextManager object.
"""
# Set storage
storage["uri"] = str((tmp_path / "out.sqlite").resolve())
# Run operations
run(
workdir=tmp_path,
datagrabber=datagrabber,
markers=markers,
storage=storage,
elements=[("sub-01",)],
)
# Check files
files = list(tmp_path.glob("*.sqlite"))
assert len(files) == 1
with expect:
run(
workdir=tmp_path,
datagrabber=datagrabber,
markers=markers,
storage=storage,
elements=element,
)
# Check files
files = list(tmp_path.glob("*.sqlite"))
assert len(files) == 1


def test_run_single_element_with_preprocessing(
@@ -128,18 +198,30 @@ def test_run_single_element_with_preprocessing(
"kind": "fMRIPrepConfoundRemover",
}
],
elements=[("sub-01",)],
elements=["sub-01"],
)
# Check files
files = list(tmp_path.glob("*.sqlite"))
assert len(files) == 1


@pytest.mark.parametrize(
"element, expect",
[
(
[("sub-01",), ("sub-03",)],
pytest.raises(RuntimeError, match="element selectors are invalid"),
),
(["sub-01", "sub-03"], nullcontext()),
],
)
def test_run_multi_element_multi_output(
tmp_path: Path,
datagrabber: dict[str, str],
markers: list[dict[str, str]],
storage: dict[str, str],
element: Elements,
expect: AbstractContextManager,
) -> None:
"""Test run function with multi element and multi output.
@@ -153,22 +235,27 @@ def test_run_multi_element_multi_output(
Testing markers as list of dictionary.
storage : dict
Testing storage as dictionary.
element : list of str or tuple
The parametrized element.
expect : typing.ContextManager
The parametrized ContextManager object.
"""
# Set storage
storage["uri"] = str((tmp_path / "out.sqlite").resolve())
storage["single_output"] = False # type: ignore
# Run operations
run(
workdir=tmp_path,
datagrabber=datagrabber,
markers=markers,
storage=storage,
elements=[("sub-01",), ("sub-03",)],
)
# Check files
files = list(tmp_path.glob("*.sqlite"))
assert len(files) == 2
with expect:
run(
workdir=tmp_path,
datagrabber=datagrabber,
markers=markers,
storage=storage,
elements=element,
)
# Check files
files = list(tmp_path.glob("*.sqlite"))
assert len(files) == 2


def test_run_multi_element_single_output(
@@ -200,7 +287,7 @@ def test_run_multi_element_single_output(
datagrabber=datagrabber,
markers=markers,
storage=storage,
elements=[("sub-01",), ("sub-03",)],
elements=["sub-01", "sub-03"],
)
# Check files
files = list(tmp_path.glob("*.sqlite"))
@@ -569,7 +656,7 @@ def test_reset_run(
datagrabber=datagrabber,
markers=markers,
storage=storage,
elements=[("sub-01",)],
elements=["sub-01"],
)
# Reset operation
reset(config={"storage": storage})
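
The reworked tests parametrize the expected outcome as a context manager: pytest.raises(...) for selectors that should now fail and contextlib.nullcontext() for ones that should pass, so a single test body exercises both paths. A minimal standalone example of that pattern (independent of junifer):

from contextlib import AbstractContextManager, nullcontext

import pytest


def lookup(selector: str) -> str:
    """Toy stand-in for element selection: only 'sub-01' exists."""
    if selector != "sub-01":
        raise RuntimeError("element selectors are invalid")
    return selector


@pytest.mark.parametrize(
    "selector, expect",
    [
        ("sub-01", nullcontext()),
        ("sub-99", pytest.raises(RuntimeError, match="invalid")),
    ],
)
def test_lookup(selector: str, expect: AbstractContextManager) -> None:
    with expect:
        assert lookup(selector) == selector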
