Skip to content

Commit

Permalink
FC using spheres formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
kaurao authored and fraimondo committed Sep 15, 2022
1 parent 28b564a commit 6df99e1
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 50 deletions.
6 changes: 3 additions & 3 deletions junifer/data/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
# Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
# License: AGPL

from pathlib import Path
from typing import Dict, List, Union, Optional, Tuple
import typing
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import pandas as pd
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike

from ..utils.logging import logger, raise_error
Expand Down
3 changes: 1 addition & 2 deletions junifer/data/tests/test_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
# Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
# License: AGPL

import pytest

import numpy as np
import pytest
from numpy.testing import assert_array_equal

from junifer.data.coordinates import (
Expand Down
121 changes: 81 additions & 40 deletions junifer/markers/functional_connectivity_spheres.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from sklearn.covariance import EmpiricalCovariance

from ..api.decorators import register_marker
from ..data.coordinates import load_coordinates
from ..utils import logger, raise_error
from .base import BaseMarker
from ..data.coordinates import load_coordinates


@register_marker
Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(
self.coords = coords
self.radius = radius
if radius is None or radius <= 0:
raise_error(f'radius should be > 0: provided {radius}')
raise_error(f"radius should be > 0: provided {radius}")
self.agg_method = agg_method
self.agg_method_params = (
{} if agg_method_params is None else agg_method_params
Expand All @@ -57,9 +57,7 @@ def __init__(
self.cor_method_params = (
{} if cor_method_params is None else cor_method_params
)
self.preproc_params = (
{} if preproc_params is None else preproc_params
)
self.preproc_params = {} if preproc_params is None else preproc_params
on = ["BOLD"]
# default to nilearn behavior
self.cor_method_params["empirical"] = self.cor_method_params.get(
Expand All @@ -68,7 +66,6 @@ def __init__(

super().__init__(on=on, name=name)


def get_output_kind(self, input: List[str]) -> List[str]:
"""Get output kind.
Expand Down Expand Up @@ -106,48 +103,92 @@ def compute(self, input: Dict) -> Dict:
"""
coords, labels = load_coordinates(self.coords)

# allow_overlap=False, smoothing_fwhm=None, standardize=False,
# standardize_confounds=True, high_variance_confounds=False,
# allow_overlap=False, smoothing_fwhm=None, standardize=False,
# standardize_confounds=True, high_variance_confounds=False,
# detrend=False, low_pass=None, high_pass=None, t_r=None
mask_img = (None if self.preproc_params.get('mask_img')
is None else self.preproc_params['mask_img'])
allow_overlap = (True if self.preproc_params.get('allow_overlap')
is None else self.preproc_params['allow_overlap'])
smoothing_fwhm = (None if self.preproc_params.get('smoothing_fwhm')
is None else self.preproc_params['smoothing_fwhm'])
standardize = (False if self.preproc_params.get('standardize')
is None else self.preproc_params['standardize'])
standardize_confounds = (True if self.preproc_params.get('standardize_confounds')
is None else self.preproc_params['standardize_confounds'])
high_variance_confounds = (False if self.preproc_params.get('high_variance_confounds')
is None else self.preproc_params['high_variance_confounds'])
detrend = (False if self.preproc_params.get('detrend')
is None else self.preproc_params['detrend'])
low_pass = (None if self.preproc_params.get('low_pass')
is None else self.preproc_params['low_pass'])
high_pass = (None if self.preproc_params.get('high_pass')
is None else self.preproc_params['high_pass'])
t_r = (None if self.preproc_params.get('t_r')
is None else self.preproc_params['t_r'])
mask_img = (
None
if self.preproc_params.get("mask_img") is None
else self.preproc_params["mask_img"]
)
allow_overlap = (
True
if self.preproc_params.get("allow_overlap") is None
else self.preproc_params["allow_overlap"]
)
smoothing_fwhm = (
None
if self.preproc_params.get("smoothing_fwhm") is None
else self.preproc_params["smoothing_fwhm"]
)
standardize = (
False
if self.preproc_params.get("standardize") is None
else self.preproc_params["standardize"]
)
standardize_confounds = (
True
if self.preproc_params.get("standardize_confounds") is None
else self.preproc_params["standardize_confounds"]
)
high_variance_confounds = (
False
if self.preproc_params.get("high_variance_confounds") is None
else self.preproc_params["high_variance_confounds"]
)
detrend = (
False
if self.preproc_params.get("detrend") is None
else self.preproc_params["detrend"]
)
low_pass = (
None
if self.preproc_params.get("low_pass") is None
else self.preproc_params["low_pass"]
)
high_pass = (
None
if self.preproc_params.get("high_pass") is None
else self.preproc_params["high_pass"]
)
t_r = (
None
if self.preproc_params.get("t_r") is None
else self.preproc_params["t_r"]
)
# params for fit_transform
confounds = (None if self.preproc_params.get('confounds')
is None else self.preproc_params['confounds'])
sample_mask = (None if self.preproc_params.get('sample_mask')
is None else self.preproc_params['sample_mask'])
confounds = (
None
if self.preproc_params.get("confounds") is None
else self.preproc_params["confounds"]
)
sample_mask = (
None
if self.preproc_params.get("sample_mask") is None
else self.preproc_params["sample_mask"]
)

masker = NiftiSpheresMasker(
coords, self.radius,
mask_img=mask_img, allow_overlap=allow_overlap,
smoothing_fwhm=smoothing_fwhm, standardize=standardize,
standardize_confounds=standardize_confounds, high_variance_confounds=high_variance_confounds,
detrend=detrend, low_pass=low_pass, high_pass=high_pass, t_r=t_r
coords,
self.radius,
mask_img=mask_img,
allow_overlap=allow_overlap,
smoothing_fwhm=smoothing_fwhm,
standardize=standardize,
standardize_confounds=standardize_confounds,
high_variance_confounds=high_variance_confounds,
detrend=detrend,
low_pass=low_pass,
high_pass=high_pass,
t_r=t_r,
)
# get the 2D timeseries
if confounds is None:
ts = masker.fit_transform(input['data'], sample_mask=sample_mask)
ts = masker.fit_transform(input["data"], sample_mask=sample_mask)
else:
ts = masker.fit_transform(input['data'], sample_mask=sample_mask,
confounds=[confounds])[0]
ts = masker.fit_transform(
input["data"], sample_mask=sample_mask, confounds=[confounds]
)[0]
if self.cor_method_params["empirical"]:
cm = ConnectivityMeasure(
cov_estimator=EmpiricalCovariance(), kind=self.cor_method
Expand Down
11 changes: 7 additions & 4 deletions junifer/markers/tests/test_functional_connectivity_spheres.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
# License: AGPL

from nilearn import datasets, image
from numpy.testing import assert_array_almost_equal, assert_array_equal

from junifer.markers.functional_connectivity_spheres import FunctionalConnectivitySpheres
from junifer.markers.functional_connectivity_spheres import (
FunctionalConnectivitySpheres,
)


def test_FunctionalConnectivitySpheres() -> None:
Expand All @@ -17,7 +18,9 @@ def test_FunctionalConnectivitySpheres() -> None:
ni_data = datasets.fetch_spm_auditory(subject_id="sub001")
fmri_img = image.concat_imgs(ni_data.func) # type: ignore

fc = FunctionalConnectivitySpheres(coords="DMNBuckner", radius=5.0, cor_method='correlation')
fc = FunctionalConnectivitySpheres(
coords="DMNBuckner", radius=5.0, cor_method="correlation"
)
out = fc.compute({"data": fmri_img})

assert "data" in out
Expand All @@ -26,4 +29,4 @@ def test_FunctionalConnectivitySpheres() -> None:
assert out["data"].shape[0] == 6
assert out["data"].shape[1] == 6
assert len(set(out["row_names"])) == 6
assert len(set(out["col_names"])) == 6
assert len(set(out["col_names"])) == 6
4 changes: 3 additions & 1 deletion junifer/testing/tests/test_testing_registry.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Provide tests for testing registry."""

from junifer.api.registry import get_step_names
import importlib

from junifer.api.registry import get_step_names


def test_testing_registry() -> None:
"""Test testing registry."""
import junifer

importlib.reload(junifer.api.registry)
importlib.reload(junifer)

Expand Down

0 comments on commit 6df99e1

Please sign in to comment.