diff --git a/junifer/data/coordinates.py b/junifer/data/coordinates.py index d8c64eee5e..29bd969c46 100644 --- a/junifer/data/coordinates.py +++ b/junifer/data/coordinates.py @@ -3,12 +3,12 @@ # Authors: Federico Raimondo # License: AGPL -from pathlib import Path -from typing import Dict, List, Union, Optional, Tuple import typing +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union -import pandas as pd import numpy as np +import pandas as pd from numpy.typing import ArrayLike from ..utils.logging import logger, raise_error diff --git a/junifer/data/tests/test_coordinates.py b/junifer/data/tests/test_coordinates.py index e263707e08..1ba4134417 100644 --- a/junifer/data/tests/test_coordinates.py +++ b/junifer/data/tests/test_coordinates.py @@ -3,9 +3,8 @@ # Authors: Federico Raimondo # License: AGPL -import pytest - import numpy as np +import pytest from numpy.testing import assert_array_equal from junifer.data.coordinates import ( diff --git a/junifer/markers/functional_connectivity_spheres.py b/junifer/markers/functional_connectivity_spheres.py index 9fed56dc59..6c748521a5 100644 --- a/junifer/markers/functional_connectivity_spheres.py +++ b/junifer/markers/functional_connectivity_spheres.py @@ -11,9 +11,9 @@ from sklearn.covariance import EmpiricalCovariance from ..api.decorators import register_marker +from ..data.coordinates import load_coordinates from ..utils import logger, raise_error from .base import BaseMarker -from ..data.coordinates import load_coordinates @register_marker @@ -48,7 +48,7 @@ def __init__( self.coords = coords self.radius = radius if radius is None or radius <= 0: - raise_error(f'radius should be > 0: provided {radius}') + raise_error(f"radius should be > 0: provided {radius}") self.agg_method = agg_method self.agg_method_params = ( {} if agg_method_params is None else agg_method_params @@ -57,9 +57,7 @@ def __init__( self.cor_method_params = ( {} if cor_method_params is None else cor_method_params ) - self.preproc_params = ( - {} if preproc_params is None else preproc_params - ) + self.preproc_params = {} if preproc_params is None else preproc_params on = ["BOLD"] # default to nilearn behavior self.cor_method_params["empirical"] = self.cor_method_params.get( @@ -68,7 +66,6 @@ def __init__( super().__init__(on=on, name=name) - def get_output_kind(self, input: List[str]) -> List[str]: """Get output kind. @@ -106,48 +103,92 @@ def compute(self, input: Dict) -> Dict: """ coords, labels = load_coordinates(self.coords) - # allow_overlap=False, smoothing_fwhm=None, standardize=False, - # standardize_confounds=True, high_variance_confounds=False, + # allow_overlap=False, smoothing_fwhm=None, standardize=False, + # standardize_confounds=True, high_variance_confounds=False, # detrend=False, low_pass=None, high_pass=None, t_r=None - mask_img = (None if self.preproc_params.get('mask_img') - is None else self.preproc_params['mask_img']) - allow_overlap = (True if self.preproc_params.get('allow_overlap') - is None else self.preproc_params['allow_overlap']) - smoothing_fwhm = (None if self.preproc_params.get('smoothing_fwhm') - is None else self.preproc_params['smoothing_fwhm']) - standardize = (False if self.preproc_params.get('standardize') - is None else self.preproc_params['standardize']) - standardize_confounds = (True if self.preproc_params.get('standardize_confounds') - is None else self.preproc_params['standardize_confounds']) - high_variance_confounds = (False if self.preproc_params.get('high_variance_confounds') - is None else self.preproc_params['high_variance_confounds']) - detrend = (False if self.preproc_params.get('detrend') - is None else self.preproc_params['detrend']) - low_pass = (None if self.preproc_params.get('low_pass') - is None else self.preproc_params['low_pass']) - high_pass = (None if self.preproc_params.get('high_pass') - is None else self.preproc_params['high_pass']) - t_r = (None if self.preproc_params.get('t_r') - is None else self.preproc_params['t_r']) + mask_img = ( + None + if self.preproc_params.get("mask_img") is None + else self.preproc_params["mask_img"] + ) + allow_overlap = ( + True + if self.preproc_params.get("allow_overlap") is None + else self.preproc_params["allow_overlap"] + ) + smoothing_fwhm = ( + None + if self.preproc_params.get("smoothing_fwhm") is None + else self.preproc_params["smoothing_fwhm"] + ) + standardize = ( + False + if self.preproc_params.get("standardize") is None + else self.preproc_params["standardize"] + ) + standardize_confounds = ( + True + if self.preproc_params.get("standardize_confounds") is None + else self.preproc_params["standardize_confounds"] + ) + high_variance_confounds = ( + False + if self.preproc_params.get("high_variance_confounds") is None + else self.preproc_params["high_variance_confounds"] + ) + detrend = ( + False + if self.preproc_params.get("detrend") is None + else self.preproc_params["detrend"] + ) + low_pass = ( + None + if self.preproc_params.get("low_pass") is None + else self.preproc_params["low_pass"] + ) + high_pass = ( + None + if self.preproc_params.get("high_pass") is None + else self.preproc_params["high_pass"] + ) + t_r = ( + None + if self.preproc_params.get("t_r") is None + else self.preproc_params["t_r"] + ) # params for fit_transform - confounds = (None if self.preproc_params.get('confounds') - is None else self.preproc_params['confounds']) - sample_mask = (None if self.preproc_params.get('sample_mask') - is None else self.preproc_params['sample_mask']) + confounds = ( + None + if self.preproc_params.get("confounds") is None + else self.preproc_params["confounds"] + ) + sample_mask = ( + None + if self.preproc_params.get("sample_mask") is None + else self.preproc_params["sample_mask"] + ) masker = NiftiSpheresMasker( - coords, self.radius, - mask_img=mask_img, allow_overlap=allow_overlap, - smoothing_fwhm=smoothing_fwhm, standardize=standardize, - standardize_confounds=standardize_confounds, high_variance_confounds=high_variance_confounds, - detrend=detrend, low_pass=low_pass, high_pass=high_pass, t_r=t_r + coords, + self.radius, + mask_img=mask_img, + allow_overlap=allow_overlap, + smoothing_fwhm=smoothing_fwhm, + standardize=standardize, + standardize_confounds=standardize_confounds, + high_variance_confounds=high_variance_confounds, + detrend=detrend, + low_pass=low_pass, + high_pass=high_pass, + t_r=t_r, ) # get the 2D timeseries if confounds is None: - ts = masker.fit_transform(input['data'], sample_mask=sample_mask) + ts = masker.fit_transform(input["data"], sample_mask=sample_mask) else: - ts = masker.fit_transform(input['data'], sample_mask=sample_mask, - confounds=[confounds])[0] + ts = masker.fit_transform( + input["data"], sample_mask=sample_mask, confounds=[confounds] + )[0] if self.cor_method_params["empirical"]: cm = ConnectivityMeasure( cov_estimator=EmpiricalCovariance(), kind=self.cor_method diff --git a/junifer/markers/tests/test_functional_connectivity_spheres.py b/junifer/markers/tests/test_functional_connectivity_spheres.py index 5a50acec16..14bdc3b86f 100644 --- a/junifer/markers/tests/test_functional_connectivity_spheres.py +++ b/junifer/markers/tests/test_functional_connectivity_spheres.py @@ -5,9 +5,10 @@ # License: AGPL from nilearn import datasets, image -from numpy.testing import assert_array_almost_equal, assert_array_equal -from junifer.markers.functional_connectivity_spheres import FunctionalConnectivitySpheres +from junifer.markers.functional_connectivity_spheres import ( + FunctionalConnectivitySpheres, +) def test_FunctionalConnectivitySpheres() -> None: @@ -17,7 +18,9 @@ def test_FunctionalConnectivitySpheres() -> None: ni_data = datasets.fetch_spm_auditory(subject_id="sub001") fmri_img = image.concat_imgs(ni_data.func) # type: ignore - fc = FunctionalConnectivitySpheres(coords="DMNBuckner", radius=5.0, cor_method='correlation') + fc = FunctionalConnectivitySpheres( + coords="DMNBuckner", radius=5.0, cor_method="correlation" + ) out = fc.compute({"data": fmri_img}) assert "data" in out @@ -26,4 +29,4 @@ def test_FunctionalConnectivitySpheres() -> None: assert out["data"].shape[0] == 6 assert out["data"].shape[1] == 6 assert len(set(out["row_names"])) == 6 - assert len(set(out["col_names"])) == 6 \ No newline at end of file + assert len(set(out["col_names"])) == 6 diff --git a/junifer/testing/tests/test_testing_registry.py b/junifer/testing/tests/test_testing_registry.py index 59a0492833..090b6709f6 100644 --- a/junifer/testing/tests/test_testing_registry.py +++ b/junifer/testing/tests/test_testing_registry.py @@ -1,12 +1,14 @@ """Provide tests for testing registry.""" -from junifer.api.registry import get_step_names import importlib +from junifer.api.registry import get_step_names + def test_testing_registry() -> None: """Test testing registry.""" import junifer + importlib.reload(junifer.api.registry) importlib.reload(junifer)