diff --git a/junifer/datagrabber/datalad_base.py b/junifer/datagrabber/datalad_base.py index 7bbc67f972..918d5dc17d 100644 --- a/junifer/datagrabber/datalad_base.py +++ b/junifer/datagrabber/datalad_base.py @@ -124,9 +124,7 @@ def _dataset_get(self, out: Dict) -> Dict: def install(self) -> None: """Install the datalad dataset into the datadir.""" logger.debug(f"Installing dataset {self.uri} to {self._datadir}") - self._dataset: dl.Dataset = dl.clone( - self.uri, self._datadir - ) + self._dataset: dl.Dataset = dl.clone(self.uri, self._datadir) logger.debug("Dataset installed") def remove(self): diff --git a/junifer/markers/functional_connectivity_atlas.py b/junifer/markers/functional_connectivity_atlas.py index b2b4ff7548..6e1253824d 100644 --- a/junifer/markers/functional_connectivity_atlas.py +++ b/junifer/markers/functional_connectivity_atlas.py @@ -5,6 +5,7 @@ # License: AGPL from typing import Dict, List + from nilearn.connectome import ConnectivityMeasure from sklearn.covariance import EmpiricalCovariance @@ -30,21 +31,29 @@ class FunctionalConnectivityAtlas(BaseMarker): """ def __init__( - self, atlas, agg_method='mean', agg_method_params=None, - cor_method='covariance', cor_method_params=None, name=None + self, + atlas, + agg_method="mean", + agg_method_params=None, + cor_method="covariance", + cor_method_params=None, + name=None, ) -> None: """Initialize the class.""" self.atlas = atlas self.agg_method = agg_method - self.agg_method_params = {} if agg_method_params is None \ - else agg_method_params + self.agg_method_params = ( + {} if agg_method_params is None else agg_method_params + ) self.cor_method = cor_method - self.cor_method_params = {} if cor_method_params is None \ - else cor_method_params + self.cor_method_params = ( + {} if cor_method_params is None else cor_method_params + ) on = ["BOLD"] # default to nilearn behavior - self.cor_method_params['empirical'] = self.cor_method_params.get( - 'empirical', False) + self.cor_method_params["empirical"] = self.cor_method_params.get( + "empirical", False + ) super().__init__(on=on, name=name) @@ -120,21 +129,25 @@ def compute(self, input: Dict) -> Dict: Returns ------- - A dict with + A dict with FC matrix as a 2D numpy array. Row names as a list. Col names as a list. """ - pa = ParcelAggregation(atlas=self.atlas, method=self.agg_method, - method_params=self.agg_method_params, - on="BOLD") + pa = ParcelAggregation( + atlas=self.atlas, + method=self.agg_method, + method_params=self.agg_method_params, + on="BOLD", + ) # get the 2D timeseries after parcel aggregation ts = pa.compute(input) - if self.cor_method_params['empirical']: - cm = ConnectivityMeasure(cov_estimator=EmpiricalCovariance(), - kind=self.cor_method) + if self.cor_method_params["empirical"]: + cm = ConnectivityMeasure( + cov_estimator=EmpiricalCovariance(), kind=self.cor_method + ) else: cm = ConnectivityMeasure(kind=self.cor_method) out = {} diff --git a/junifer/markers/tests/test_functional_connectivity_atlas.py b/junifer/markers/tests/test_functional_connectivity_atlas.py index 8b28adc04b..1f5f3724ca 100644 --- a/junifer/markers/tests/test_functional_connectivity_atlas.py +++ b/junifer/markers/tests/test_functional_connectivity_atlas.py @@ -5,49 +5,50 @@ # License: AGPL from nilearn import datasets, image -from nilearn.maskers import NiftiLabelsMasker from nilearn.connectome import ConnectivityMeasure +from nilearn.maskers import NiftiLabelsMasker from numpy.testing import assert_array_almost_equal, assert_array_equal +from junifer.markers.functional_connectivity_atlas import ( + FunctionalConnectivityAtlas, +) from junifer.markers.parcel import ParcelAggregation -from junifer.markers.functional_connectivity_atlas \ - import FunctionalConnectivityAtlas def test_FunctionalConnectivityAtlas() -> None: """Test FunctionalConnectivityAtlas.""" # get a dataset - ni_data = datasets.fetch_spm_auditory(subject_id='sub001') + ni_data = datasets.fetch_spm_auditory(subject_id="sub001") fmri_img = image.concat_imgs(ni_data.func) # type: ignore - fc = FunctionalConnectivityAtlas(atlas='Schaefer100x7') - out = fc.compute({'data': fmri_img}) + fc = FunctionalConnectivityAtlas(atlas="Schaefer100x7") + out = fc.compute({"data": fmri_img}) - assert 'data' in out - assert 'row_names' in out - assert 'col_names' in out - assert out['data'].shape[0] == 100 - assert out['data'].shape[1] == 100 - assert len(set(out['row_names'])) == 100 - assert len(set(out['col_names'])) == 100 + assert "data" in out + assert "row_names" in out + assert "col_names" in out + assert out["data"].shape[0] == 100 + assert out["data"].shape[1] == 100 + assert len(set(out["row_names"])) == 100 + assert len(set(out["col_names"])) == 100 # get the timeseries using pa - pa = ParcelAggregation(atlas='Schaefer100x7', method='mean', - on="BOLD") + pa = ParcelAggregation(atlas="Schaefer100x7", method="mean", on="BOLD") ts = pa.compute({"data": fmri_img}) # compare with nilearn # Get the testing atlas (for nilearn) - atlas = datasets.fetch_atlas_schaefer_2018(n_rois=100, yeo_networks=7, - resolution_mm=2) - masker = NiftiLabelsMasker(labels_img=atlas['maps'], standardize=False) + atlas = datasets.fetch_atlas_schaefer_2018( + n_rois=100, yeo_networks=7, resolution_mm=2 + ) + masker = NiftiLabelsMasker(labels_img=atlas["maps"], standardize=False) ts_ni = masker.fit_transform(fmri_img) # check the TS are almost equal - assert_array_equal(ts_ni, ts['data']) + assert_array_equal(ts_ni, ts["data"]) # Check that FC are almost equal - cm = ConnectivityMeasure(kind='covariance') + cm = ConnectivityMeasure(kind="covariance") out_ni = cm.fit_transform([ts_ni])[0] - assert_array_almost_equal(out_ni, out['data'], decimal=3) + assert_array_almost_equal(out_ni, out["data"], decimal=3) diff --git a/junifer/storage/tests/test_sqlite.py b/junifer/storage/tests/test_sqlite.py index d7eb6d7862..0d3850f128 100644 --- a/junifer/storage/tests/test_sqlite.py +++ b/junifer/storage/tests/test_sqlite.py @@ -8,9 +8,9 @@ from typing import List, Union import numpy as np -from numpy.testing import assert_array_equal import pandas as pd import pytest +from numpy.testing import assert_array_equal from pandas.testing import assert_frame_equal from sqlalchemy import create_engine diff --git a/junifer/testing/datagrabbers.py b/junifer/testing/datagrabbers.py index 6d20576755..f591fc4b43 100644 --- a/junifer/testing/datagrabbers.py +++ b/junifer/testing/datagrabbers.py @@ -7,8 +7,8 @@ import tempfile from typing import Dict, List -from nilearn import datasets, image import nibabel as nib +from nilearn import datasets, image from ..datagrabber.base import BaseDataGrabber diff --git a/junifer/testing/tests/test_spmauditory_datagrabber.py b/junifer/testing/tests/test_spmauditory_datagrabber.py index 11851419c2..711712f85b 100644 --- a/junifer/testing/tests/test_spmauditory_datagrabber.py +++ b/junifer/testing/tests/test_spmauditory_datagrabber.py @@ -9,21 +9,21 @@ def test_SPMAuditoryTestingDatagrabber() -> None: """Test SPM Auditory datagrabber.""" expected_elements = [ - 'sub001', - 'sub002', - 'sub003', - 'sub004', - 'sub005', - 'sub006', - 'sub007', - 'sub008', - 'sub009', - 'sub010' + "sub001", + "sub002", + "sub003", + "sub004", + "sub005", + "sub006", + "sub007", + "sub008", + "sub009", + "sub010", ] with SPMAuditoryTestingDatagrabber() as dg: all_elements = dg.get_elements() assert set(all_elements) == set(expected_elements) - out = dg['sub001'] + out = dg["sub001"] assert "BOLD" in out assert out["BOLD"]["path"].exists() assert out["BOLD"]["path"].is_file()