Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH]: Add support for multiple preprocessors in the pipeline #263

Merged
merged 6 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changes/newsfragments/263.change
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Rename the ``preprocessor`` parameter to ``preprocessors`` in :func:`.run` and the ``preprocessing`` parameter to ``preprocessors`` in :class:`.MarkerCollection` to accept multiple preprocessors by `Synchon Mandal`_
16 changes: 13 additions & 3 deletions junifer/api/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,21 +200,31 @@ def run(

"""
configure_logging(level=verbose)
# TODO: add validation
# TODO(synchon): add validation
# Parse YAML
config = parse_yaml(filepath) # type: ignore
# Retrieve working directory
workdir = config["workdir"]
# Fetch datagrabber
datagrabber = config["datagrabber"]
# Fetch markers
markers = config["markers"]
# Fetch storage
storage = config["storage"]
preprocessor = config.get("preprocess")
# Fetch preprocessors
preprocessors = config.get("preprocess")
# Convert to list if single preprocessor
if preprocessors is not None and not isinstance(preprocessors, list):
preprocessors = [preprocessors]
# Parse elements
elements = _parse_elements(element, config)
# Perform operation
api_run(
workdir=workdir,
datagrabber=datagrabber,
markers=markers,
storage=storage,
preprocessor=preprocessor,
preprocessors=preprocessors,
elements=elements,
)

Expand Down
22 changes: 13 additions & 9 deletions junifer/api/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def run(
datagrabber: Dict,
markers: List[Dict],
storage: Dict,
preprocessor: Optional[Dict] = None,
preprocessors: Optional[List[Dict]] = None,
elements: Union[str, List[Union[str, Tuple]], Tuple, None] = None,
) -> None:
"""Run the pipeline on the selected element.
Expand All @@ -104,10 +104,10 @@ def run(
Storage to use. Must have a key ``kind`` with the kind of
storage to use. All other keys are passed to the storage
init function.
preprocessor : dict, optional
Preprocessor to use. Must have a key ``kind`` with the kind of
preprocessor to use. All other keys are passed to the preprocessor
init function (default None).
preprocessors : list of dict, optional
List of preprocessors to use. Each preprocessor is a dict with at
least a key ``kind`` specifying the preprocessor to use. All other keys
are passed to the preprocessor init function (default None).
elements : str or tuple or list of str or tuple, optional
Element(s) to process. Will be used to index the DataGrabber
(default None).
Expand Down Expand Up @@ -152,15 +152,19 @@ def run(
storage_object = typing.cast(BaseFeatureStorage, storage_object)

# Get preprocessor to use (if provided)
if preprocessor is not None:
preprocessor_object = _get_preprocessor(preprocessor)
if preprocessors is not None:
_preprocessors = [x.copy() for x in preprocessors]
built_preprocessors = []
for preprocessor in _preprocessors:
preprocessor_object = _get_preprocessor(preprocessor)
built_preprocessors.append(preprocessor_object)
else:
preprocessor_object = None
built_preprocessors = None

# Create new marker collection
mc = MarkerCollection(
markers=built_markers,
preprocessing=preprocessor_object,
preprocessors=built_preprocessors,
storage=storage_object,
)
# Fit elements
Expand Down
8 changes: 5 additions & 3 deletions junifer/api/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,11 @@ def test_run_single_element_with_preprocessing(tmp_path: Path) -> None:
}
],
storage=storage,
preprocessor={
"kind": "fMRIPrepConfoundRemover",
},
preprocessors=[
{
"kind": "fMRIPrepConfoundRemover",
}
],
elements=["sub-01"],
)
# Check files
Expand Down
33 changes: 22 additions & 11 deletions junifer/markers/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..datareader.default import DefaultDataReader
from ..markers.base import BaseMarker
from ..pipeline import PipelineStepMixin
from ..preprocess.base import BasePreprocessor
from ..storage.base import BaseFeatureStorage
from ..utils import logger

Expand All @@ -27,8 +28,8 @@ class MarkerCollection:
The markers to compute.
datareader : DataReader-like object, optional
The DataReader to use (default None).
preprocessing : preprocessing-like, optional
The preprocessing steps to apply.
preprocessors : list of preprocessing-like, optional
The preprocessors to apply (default None).
storage : storage-like, optional
The storage to use (default None).

Expand All @@ -38,7 +39,7 @@ def __init__(
self,
markers: List[BaseMarker],
datareader: Optional[PipelineStepMixin] = None,
preprocessing: Optional[PipelineStepMixin] = None,
preprocessors: Optional[List[BasePreprocessor]] = None,
storage: Optional[BaseFeatureStorage] = None,
):
# Check that the markers have different names
Expand All @@ -53,7 +54,7 @@ def __init__(
if datareader is None:
datareader = DefaultDataReader()
self._datareader = datareader
self._preprocessing = preprocessing
self._preprocessors = preprocessors
self._storage = storage

def fit(self, input: Dict[str, Dict]) -> Optional[Dict]:
Expand All @@ -79,9 +80,14 @@ def fit(self, input: Dict[str, Dict]) -> Optional[Dict]:
data = self._datareader.fit_transform(input)

# Apply preprocessing steps
if self._preprocessing is not None:
logger.info("Preprocessing data")
data = self._preprocessing.fit_transform(data)
if self._preprocessors is not None:
for preprocessor in self._preprocessors:
logger.info(
"Preprocessing data with "
f"{preprocessor.__class__.__name__}"
)
# Mutate data after every iteration
data = preprocessor.fit_transform(data)

# Compute markers
out = {}
Expand Down Expand Up @@ -116,10 +122,15 @@ def validate(self, datagrabber: "BaseDataGrabber") -> None:
t_data = self._datareader.validate(t_data)
logger.info(f"Data Reader output type: {t_data}")

if self._preprocessing is not None:
logger.info("Validating Preprocessor:")
t_data = self._preprocessing.validate(t_data)
logger.info(f"Preprocess output type: {t_data}")
if self._preprocessors is not None:
for preprocessor in self._preprocessors:
logger.info(
"Validating Preprocessor: "
f"{preprocessor.__class__.__name__}"
)
# Validate preprocessor
t_data = preprocessor.validate(t_data)
logger.info(f"Preprocess output type: {t_data}")

for marker in self._markers:
logger.info(f"Validating Marker: {marker.name}")
Expand Down
8 changes: 4 additions & 4 deletions junifer/markers/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_marker_collection() -> None:
]
mc = MarkerCollection(markers=markers) # type: ignore
assert mc._markers == markers
assert mc._preprocessing is None
assert mc._preprocessors is None
assert mc._storage is None
assert isinstance(mc._datareader, DefaultDataReader)

Expand Down Expand Up @@ -97,7 +97,7 @@ def fit_transform(self, input):

mc2 = MarkerCollection(
markers=markers, # type: ignore
preprocessing=BypassPreprocessing(),
preprocessors=[BypassPreprocessing()], # type: ignore
datareader=DefaultDataReader(),
)
assert isinstance(mc2._datareader, DefaultDataReader)
Expand Down Expand Up @@ -128,10 +128,10 @@ def test_marker_collection_with_preprocessing() -> None:
]
mc = MarkerCollection(
markers=markers, # type: ignore
preprocessing=fMRIPrepConfoundRemover(),
preprocessors=[fMRIPrepConfoundRemover()],
)
assert mc._markers == markers
assert mc._preprocessing is not None
assert mc._preprocessors is not None
assert mc._storage is None
assert isinstance(mc._datareader, DefaultDataReader)

Expand Down