Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH]: Improve BasePreprocessor and fMRIPrepConfoundRemover #260

Merged
merged 10 commits into from
Oct 18, 2023
1 change: 1 addition & 0 deletions docs/changes/newsfragments/260.enh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve :class:`.BasePreprocessor` for easy subclassing and adapt :class:`.fMRIPrepConfoundRemover` to it by `Synchon Mandal`_
104 changes: 62 additions & 42 deletions junifer/preprocess/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,40 @@ class BasePreprocessor(ABC, PipelineStepMixin, UpdateMetaMixin):
on : str or list of str, optional
The kind of data to apply the preprocessor to. If None,
will work on all available data (default None).
required_data_types : str or list of str, optional
The kind of data types needed for computation. If None,
will be equal to ``on`` (default None).

Raises
------
ValueError
If required input data type(s) is(are) not found.

"""

def __init__(
self,
on: Optional[Union[List[str], str]] = None,
required_data_types: Optional[Union[List[str], str]] = None,
) -> None:
"""Initialize the class."""
# Use all data types if not provided
if on is None:
on = self.get_valid_inputs()
# Convert data types to list
if not isinstance(on, list):
on = [on]

# Check if required inputs are found
if any(x not in self.get_valid_inputs() for x in on):
name = self.__class__.__name__
wrong_on = [x for x in on if x not in self.get_valid_inputs()]
raise ValueError(f"{name} cannot be computed on {wrong_on}")
raise_error(f"{name} cannot be computed on {wrong_on}")
self._on = on
# Set required data types for validation
if required_data_types is None:
self._required_data_types = on
else:
self._required_data_types = required_data_types

def validate_input(self, input: List[str]) -> List[str]:
"""Validate input.
Expand All @@ -55,15 +73,32 @@ def validate_input(self, input: List[str]) -> List[str]:
------
ValueError
If the input does not have the required data.

"""
if not any(x in input for x in self._on):
if any(x not in input for x in self._required_data_types):
raise_error(
"Input does not have the required data."
f"\t Input: {input}"
f"\t Required (any of): {self._on}"
f"\t Required (all of): {self._required_data_types}"
)
return [x for x in self._on if x in input]

@abstractmethod
def get_valid_inputs(self) -> List[str]:
"""Get valid data types for input.

Returns
-------
list of str
The list of data types that can be used as input for this
preprocessor.

"""
raise_error(
msg="Concrete classes need to implement get_valid_inputs().",
klass=NotImplementedError,
)

@abstractmethod
def get_output_type(self, input: List[str]) -> List[str]:
"""Get output type.
Expand All @@ -87,17 +122,34 @@ def get_output_type(self, input: List[str]) -> List[str]:
)

@abstractmethod
def get_valid_inputs(self) -> List[str]:
"""Get valid data types for input.
def preprocess(
self,
input: Dict[str, Any],
extra_input: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any]]:
"""Preprocess.

Parameters
----------
input : dict
A single input from the Junifer Data object to preprocess.
extra_input : dict, optional
The other fields in the Junifer Data object. Useful for accessing
other data kind that needs to be used in the computation. For
example, the confound removers can make use of the
confounds if available (default None).

Returns
-------
list of str
The list of data types that can be used as input for this
preprocessor.
str
The key to store the output in the Junifer Data object.
dict
The computed result as dictionary. This will be stored in the
Junifer Data object under the key ``data`` of the data type.

"""
raise_error(
msg="Concrete classes need to implement get_valid_inputs().",
msg="Concrete classes need to implement preprocess().",
klass=NotImplementedError,
)

Expand Down Expand Up @@ -146,35 +198,3 @@ def _fit_transform(

self.update_meta(out[key], "preprocess")
return out

@abstractmethod
def preprocess(
self,
input: Dict[str, Any],
extra_input: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any]]:
"""Preprocess.

Parameters
----------
input : dict
A single input from the Junifer Data object to preprocess.
extra_input : dict, optional
The other fields in the Junifer Data object. Useful for accessing
other data kind that needs to be used in the computation. For
example, the confound removers can make use of the
confounds if available (default None).

Returns
-------
key : str
The key to store the output in the Junifer Data object.
object : dict
The computed result as dictionary. This will be stored in the
Junifer Data object under the key 'key'.

"""
raise_error(
msg="Concrete classes need to implement preprocess().",
klass=NotImplementedError,
)
95 changes: 43 additions & 52 deletions junifer/preprocess/confounds/fmriprep_confound_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def __init__(
t_r: Optional[float] = None,
masks: Union[str, Dict, List[Union[Dict, str]], None] = None,
) -> None:
"""Initialise the class."""
"""Initialize the class."""
if strategy is None:
strategy = {
"motion": "full",
Expand Down Expand Up @@ -208,48 +208,30 @@ def __init__(
"include it in the future",
klass=ValueError,
)
super().__init__()

def validate_input(self, input: List[str]) -> List[str]:
"""Validate the input to the pipeline step.
super().__init__(
on="BOLD", required_data_types=["BOLD", "BOLD_confounds"]
)

Parameters
----------
input : list of str
The input to the pipeline step. The list must contain the
available Junifer Data object keys.
def get_valid_inputs(self) -> List[str]:
"""Get valid data types for input.

Returns
-------
list of str
The actual elements of the input that will be processed by this
pipeline step.

Raises
------
ValueError
If the input does not have the required data.
The list of data types that can be used as input for this
preprocessor.

"""
_required_inputs = ["BOLD", "BOLD_confounds"]
if any(x not in input for x in _required_inputs):
raise_error(
msg="Input does not have the required data. \n"
f"Input: {input} \n"
f"Required (all off): {_required_inputs} \n",
klass=ValueError,
)

return [x for x in self._on if x in input]
return ["BOLD"]

def get_output_type(self, input: List[str]) -> List[str]:
"""Get the kind of the pipeline step.
"""Get output type.

Parameters
----------
input : list of str
The input to the pipeline step. The list must contain the
available Junifer Data object keys.
The input to the preprocessor. The list must contain the
available Junifer Data dictionary keys.

Returns
-------
Expand All @@ -261,17 +243,6 @@ def get_output_type(self, input: List[str]) -> List[str]:
# Does not add any new keys
return input

def get_valid_inputs(self) -> List[str]:
"""Get the valid inputs for the pipeline step.

Returns
-------
list of str
The valid inputs for the pipeline step.

"""
return ["BOLD"]

def _map_adhoc_to_fmriprep(self, input: Dict[str, Any]) -> None:
"""Map the adhoc format to the fmpriprep format spec.

Expand Down Expand Up @@ -333,6 +304,11 @@ def _process_fmriprep_spec(
spike_name : str
Name of the confound to use for spike detection

Raises
------
ValueError
If invalid confounds file is found.

"""
confounds_df = input["data"]
available_vars = confounds_df.columns
Expand All @@ -347,7 +323,7 @@ def _process_fmriprep_spec(

if any(x not in available_vars for x in t_basics):
missing = [x for x in t_basics if x not in available_vars]
raise ValueError(
raise_error(
"Invalid confounds file. Missing basic confounds: "
f"{missing}. "
"Check if this file is really an fmriprep confounds file. "
Expand Down Expand Up @@ -377,7 +353,7 @@ def _process_fmriprep_spec(
spike_name = "framewise_displacement"
if self.spike is not None:
if spike_name not in available_vars:
raise ValueError(
raise_error(
"Invalid confounds file. Missing framewise_displacement "
"(spike) confound. "
"Check if this file is really an fmriprep confounds file. "
Expand Down Expand Up @@ -460,17 +436,32 @@ def _validate_data(
Dictionary containing the rest of the Junifer Data object. Must
include the ``BOLD_confounds`` key.

Raises
------
ValueError
If ``extra_input`` is None or
if ``"BOLD_confounds"`` is not found in ``extra_input`` or
if ``"data"`` key is not found in ``"BOLD_confounds"`` or
if ``"data"`` is not pandas.DataFrame or
if image time series and confounds have different lengths or
if ``"format"`` is not found in ``"BOLD_confounds"`` or
if ``format = "adhoc"`` and ``"mappings"`` key or ``"fmriprep"``
key or correct fMRIPrep mappings or required fMRIPrep mappings are
not found or if invalid confounds format is found.

"""

# Bold must be 4D niimg
check_niimg_4d(input["data"])

if extra_input is None:
raise_error("No extra input provided", ValueError)
raise_error(msg="No extra input provided", klass=ValueError)
if "BOLD_confounds" not in extra_input:
raise_error("No BOLD_confounds provided", ValueError)
raise_error(msg="No BOLD_confounds provided", klass=ValueError)
if "data" not in extra_input["BOLD_confounds"]:
raise_error("No BOLD_confounds data provided", ValueError)
raise_error(
msg="No BOLD_confounds data provided", klass=ValueError
)
# Confounds must be a dataframe
if not isinstance(extra_input["BOLD_confounds"]["data"], pd.DataFrame):
raise_error(
Expand Down Expand Up @@ -528,14 +519,14 @@ def _validate_data(
]

if len(missing) > 0:
raise ValueError(
raise_error(
"Invalid confounds file. Missing columns: "
f"{missing}. "
"Check if this file matches the adhoc specification for "
"this dataset."
)
elif t_format != "fmriprep":
raise ValueError(f"Invalid confounds format {t_format}")
raise_error(f"Invalid confounds format {t_format}")

def _remove_confounds(
self,
Expand Down Expand Up @@ -614,18 +605,18 @@ def preprocess(
Parameters
----------
input : dict
A single input from the Junifer Data object in which to preprocess.
A single input from the Junifer Data object to preprocess.
extra_input : dict, optional
The other fields in the Junifer Data object. Must include the
``BOLD_confounds`` key.

Returns
-------
key : str
str
The key to store the output in the Junifer Data object.
object : dict
dict
The computed result as dictionary. This will be stored in the
Junifer Data object under the key ``key``.
Junifer Data object under the key ``data`` of the data type.

"""
self._validate_data(input, extra_input)
Expand Down
Loading
Loading