Skip to content

Commit

Permalink
formatting fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
kaurao committed Sep 14, 2022
1 parent 6221205 commit 8370d9e
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 52 deletions.
4 changes: 1 addition & 3 deletions junifer/datagrabber/datalad_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,7 @@ def _dataset_get(self, out: Dict) -> Dict:
def install(self) -> None:
"""Install the datalad dataset into the datadir."""
logger.debug(f"Installing dataset {self.uri} to {self._datadir}")
self._dataset: dl.Dataset = dl.clone(
self.uri, self._datadir
)
self._dataset: dl.Dataset = dl.clone(self.uri, self._datadir)
logger.debug("Dataset installed")

def remove(self):
Expand Down
43 changes: 28 additions & 15 deletions junifer/markers/functional_connectivity_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# License: AGPL

from typing import Dict, List

from nilearn.connectome import ConnectivityMeasure
from sklearn.covariance import EmpiricalCovariance

Expand All @@ -30,21 +31,29 @@ class FunctionalConnectivityAtlas(BaseMarker):
"""

def __init__(
self, atlas, agg_method='mean', agg_method_params=None,
cor_method='covariance', cor_method_params=None, name=None
self,
atlas,
agg_method="mean",
agg_method_params=None,
cor_method="covariance",
cor_method_params=None,
name=None,
) -> None:
"""Initialize the class."""
self.atlas = atlas
self.agg_method = agg_method
self.agg_method_params = {} if agg_method_params is None \
else agg_method_params
self.agg_method_params = (
{} if agg_method_params is None else agg_method_params
)
self.cor_method = cor_method
self.cor_method_params = {} if cor_method_params is None \
else cor_method_params
self.cor_method_params = (
{} if cor_method_params is None else cor_method_params
)
on = ["BOLD"]
# default to nilearn behavior
self.cor_method_params['empirical'] = self.cor_method_params.get(
'empirical', False)
self.cor_method_params["empirical"] = self.cor_method_params.get(
"empirical", False
)

super().__init__(on=on, name=name)

Expand Down Expand Up @@ -120,21 +129,25 @@ def compute(self, input: Dict) -> Dict:
Returns
-------
A dict with
A dict with
FC matrix as a 2D numpy array.
Row names as a list.
Col names as a list.
"""
pa = ParcelAggregation(atlas=self.atlas, method=self.agg_method,
method_params=self.agg_method_params,
on="BOLD")
pa = ParcelAggregation(
atlas=self.atlas,
method=self.agg_method,
method_params=self.agg_method_params,
on="BOLD",
)
# get the 2D timeseries after parcel aggregation
ts = pa.compute(input)

if self.cor_method_params['empirical']:
cm = ConnectivityMeasure(cov_estimator=EmpiricalCovariance(),
kind=self.cor_method)
if self.cor_method_params["empirical"]:
cm = ConnectivityMeasure(
cov_estimator=EmpiricalCovariance(), kind=self.cor_method
)
else:
cm = ConnectivityMeasure(kind=self.cor_method)
out = {}
Expand Down
43 changes: 22 additions & 21 deletions junifer/markers/tests/test_functional_connectivity_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,49 +5,50 @@
# License: AGPL

from nilearn import datasets, image
from nilearn.maskers import NiftiLabelsMasker
from nilearn.connectome import ConnectivityMeasure
from nilearn.maskers import NiftiLabelsMasker
from numpy.testing import assert_array_almost_equal, assert_array_equal

from junifer.markers.functional_connectivity_atlas import (
FunctionalConnectivityAtlas,
)
from junifer.markers.parcel import ParcelAggregation
from junifer.markers.functional_connectivity_atlas \
import FunctionalConnectivityAtlas


def test_FunctionalConnectivityAtlas() -> None:
"""Test FunctionalConnectivityAtlas."""

# get a dataset
ni_data = datasets.fetch_spm_auditory(subject_id='sub001')
ni_data = datasets.fetch_spm_auditory(subject_id="sub001")
fmri_img = image.concat_imgs(ni_data.func) # type: ignore

fc = FunctionalConnectivityAtlas(atlas='Schaefer100x7')
out = fc.compute({'data': fmri_img})
fc = FunctionalConnectivityAtlas(atlas="Schaefer100x7")
out = fc.compute({"data": fmri_img})

assert 'data' in out
assert 'row_names' in out
assert 'col_names' in out
assert out['data'].shape[0] == 100
assert out['data'].shape[1] == 100
assert len(set(out['row_names'])) == 100
assert len(set(out['col_names'])) == 100
assert "data" in out
assert "row_names" in out
assert "col_names" in out
assert out["data"].shape[0] == 100
assert out["data"].shape[1] == 100
assert len(set(out["row_names"])) == 100
assert len(set(out["col_names"])) == 100

# get the timeseries using pa
pa = ParcelAggregation(atlas='Schaefer100x7', method='mean',
on="BOLD")
pa = ParcelAggregation(atlas="Schaefer100x7", method="mean", on="BOLD")
ts = pa.compute({"data": fmri_img})

# compare with nilearn
# Get the testing atlas (for nilearn)
atlas = datasets.fetch_atlas_schaefer_2018(n_rois=100, yeo_networks=7,
resolution_mm=2)
masker = NiftiLabelsMasker(labels_img=atlas['maps'], standardize=False)
atlas = datasets.fetch_atlas_schaefer_2018(
n_rois=100, yeo_networks=7, resolution_mm=2
)
masker = NiftiLabelsMasker(labels_img=atlas["maps"], standardize=False)
ts_ni = masker.fit_transform(fmri_img)

# check the TS are almost equal
assert_array_equal(ts_ni, ts['data'])
assert_array_equal(ts_ni, ts["data"])

# Check that FC are almost equal
cm = ConnectivityMeasure(kind='covariance')
cm = ConnectivityMeasure(kind="covariance")
out_ni = cm.fit_transform([ts_ni])[0]
assert_array_almost_equal(out_ni, out['data'], decimal=3)
assert_array_almost_equal(out_ni, out["data"], decimal=3)
2 changes: 1 addition & 1 deletion junifer/storage/tests/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from typing import List, Union

import numpy as np
from numpy.testing import assert_array_equal
import pandas as pd
import pytest
from numpy.testing import assert_array_equal
from pandas.testing import assert_frame_equal
from sqlalchemy import create_engine

Expand Down
2 changes: 1 addition & 1 deletion junifer/testing/datagrabbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import tempfile
from typing import Dict, List

from nilearn import datasets, image
import nibabel as nib
from nilearn import datasets, image

from ..datagrabber.base import BaseDataGrabber

Expand Down
22 changes: 11 additions & 11 deletions junifer/testing/tests/test_spmauditory_datagrabber.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@
def test_SPMAuditoryTestingDatagrabber() -> None:
"""Test SPM Auditory datagrabber."""
expected_elements = [
'sub001',
'sub002',
'sub003',
'sub004',
'sub005',
'sub006',
'sub007',
'sub008',
'sub009',
'sub010'
"sub001",
"sub002",
"sub003",
"sub004",
"sub005",
"sub006",
"sub007",
"sub008",
"sub009",
"sub010",
]
with SPMAuditoryTestingDatagrabber() as dg:
all_elements = dg.get_elements()
assert set(all_elements) == set(expected_elements)
out = dg['sub001']
out = dg["sub001"]
assert "BOLD" in out
assert out["BOLD"]["path"].exists()
assert out["BOLD"]["path"].is_file()
Expand Down

0 comments on commit 8370d9e

Please sign in to comment.