Improve test + implement using SphereAggregation
fraimondo committed Oct 5, 2022
1 parent 1e87114 commit 242c5a4
Showing 2 changed files with 95 additions and 122 deletions.
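
Before diving into the diff, a minimal sketch (not part of the commit) of how the marker is instantiated after this change; the argument values are the ones used in the updated test, and the catch-all `preproc_params` dict from the previous version is gone:

```python
from junifer.markers.functional_connectivity_spheres import (
    FunctionalConnectivitySpheres,
)

fc = FunctionalConnectivitySpheres(
    coords="DMNBuckner",       # name of a coordinates list in junifer.data.coordinates
    radius=5.0,                # sphere radius in mm
    agg_method="mean",         # aggregation of voxels within each sphere
    cor_method="correlation",  # passed on to nilearn's ConnectivityMeasure
)
```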
176 changes: 56 additions & 120 deletions junifer/markers/functional_connectivity_spheres.py
@@ -4,45 +4,51 @@
# Kaustubh R. Patil <k.patil@fz-juelich.de>
# License: AGPL

from typing import Dict, List
from typing import Dict, List, Optional

from nilearn.connectome import ConnectivityMeasure
from nilearn.maskers import NiftiSpheresMasker
from sklearn.covariance import EmpiricalCovariance

from ..api.decorators import register_marker
from ..data.coordinates import load_coordinates
from ..utils import logger, raise_error
from .base import BaseMarker
from .sphere_aggregation import SphereAggregation


@register_marker
class FunctionalConnectivitySpheres(BaseMarker):
"""Class for functional connectivity.
"""Class for functional connectivity using coordinates (spheres).
Parameters
Parameters
----------
seeds
mask_img
agg_method
agg_method_params
cor_method
cor_method_params
name
# ToDo: add mask_img
coords: str
The name of the coordinates list to use. See
:mod:`junifer.data.coordinates`
radius: float
The radius of the sphere in mm. If None, the signal will be extracted
from a single voxel. See :class:`nilearn.maskers.NiftiSpheresMasker`
for more information.
agg_method: str
The aggregation method to use.
See :func:`junifer.stats.get_aggfunc_by_name` for more information.
agg_method_params: Dict, optional
The parameters to pass to the aggregation method.
name : str, optional
The name of the marker. By default, it will use
KIND_FunctionalConnectivitySpheres where KIND is the kind of data it
was applied to (default None).
"""

def __init__(
self,
coords,
radius,
agg_method="mean",
agg_method_params=None,
cor_method="covariance",
cor_method_params=None,
preproc_params=None,
name=None,
coords: str,
radius: float,
agg_method: str = "mean",
agg_method_params: Optional[Dict] = None,
cor_method: str = "covariance",
cor_method_params: Optional[Dict] = None,
name: Optional[str] = None,
) -> None:
"""Initialize the class."""
self.coords = coords
@@ -57,7 +63,6 @@ def __init__(
self.cor_method_params = (
{} if cor_method_params is None else cor_method_params
)
self.preproc_params = {} if preproc_params is None else preproc_params
on = ["BOLD"]
# default to nilearn behavior
self.cor_method_params["empirical"] = self.cor_method_params.get(
@@ -84,121 +89,52 @@ def get_output_kind(self, input: List[str]) -> List[str]:
outputs = ["matrix"]
return outputs

def compute(self, input: Dict) -> Dict:
def compute(self, input: Dict, extra_input: Optional[Dict] = None) -> Dict:
"""Compute.
Parameters
----------
input : Dict[str, Dict]
The input to the pipeline step. The list must contain the
available Junifer Data dictionary keys.
A single input from the pipeline data object in which to compute
the marker.
extra_input : Dict, optional
The other fields in the pipeline data object. Useful for accessing
other data kind that needs to be used in the computation. For
example, the functional connectivity markers can make use of the
confounds if available (default None).
Returns
-------
A dict with
FC matrix as a 2D numpy array.
Row names as a list.
Col names as a list.
dict
The computed result as dictionary. The following data will be
included in the dictionary:
- 'data': FC matrix as a 2D numpy array.
- 'row_names': Row names as a list.
- 'col_names': Col names as a list.
- 'kind': The kind of matrix (tril, triu or full)
"""
coords, labels = load_coordinates(self.coords)

# allow_overlap=False, smoothing_fwhm=None, standardize=False,
# standardize_confounds=True, high_variance_confounds=False,
# detrend=False, low_pass=None, high_pass=None, t_r=None
mask_img = (
None
if self.preproc_params.get("mask_img") is None
else self.preproc_params["mask_img"]
)
allow_overlap = (
True
if self.preproc_params.get("allow_overlap") is None
else self.preproc_params["allow_overlap"]
)
smoothing_fwhm = (
None
if self.preproc_params.get("smoothing_fwhm") is None
else self.preproc_params["smoothing_fwhm"]
)
standardize = (
False
if self.preproc_params.get("standardize") is None
else self.preproc_params["standardize"]
)
standardize_confounds = (
True
if self.preproc_params.get("standardize_confounds") is None
else self.preproc_params["standardize_confounds"]
)
high_variance_confounds = (
False
if self.preproc_params.get("high_variance_confounds") is None
else self.preproc_params["high_variance_confounds"]
)
detrend = (
False
if self.preproc_params.get("detrend") is None
else self.preproc_params["detrend"]
)
low_pass = (
None
if self.preproc_params.get("low_pass") is None
else self.preproc_params["low_pass"]
)
high_pass = (
None
if self.preproc_params.get("high_pass") is None
else self.preproc_params["high_pass"]
)
t_r = (
None
if self.preproc_params.get("t_r") is None
else self.preproc_params["t_r"]
)
# params for fit_transform
confounds = (
None
if self.preproc_params.get("confounds") is None
else self.preproc_params["confounds"]
)
sample_mask = (
None
if self.preproc_params.get("sample_mask") is None
            else self.preproc_params["sample_mask"]
        )
sa = SphereAggregation(
coords=self.coords,
radius=self.radius,
method=self.agg_method,
method_params=self.agg_method_params,
on="BOLD",
)

masker = NiftiSpheresMasker(
coords,
self.radius,
mask_img=mask_img,
allow_overlap=allow_overlap,
smoothing_fwhm=smoothing_fwhm,
standardize=standardize,
standardize_confounds=standardize_confounds,
high_variance_confounds=high_variance_confounds,
detrend=detrend,
low_pass=low_pass,
high_pass=high_pass,
t_r=t_r,
)
# get the 2D timeseries
if confounds is None:
ts = masker.fit_transform(input["data"], sample_mask=sample_mask)
else:
ts = masker.fit_transform(
input["data"], sample_mask=sample_mask, confounds=[confounds]
)[0]
ts = sa.compute(input)

if self.cor_method_params["empirical"]:
cm = ConnectivityMeasure(
cov_estimator=EmpiricalCovariance(), kind=self.cor_method
cov_estimator=EmpiricalCovariance(), # type: ignore
kind=self.cor_method,
)
else:
cm = ConnectivityMeasure(kind=self.cor_method)
out = {}
out["data"] = cm.fit_transform([ts])[0]
out["row_names"] = labels
out["col_names"] = labels
out["data"] = cm.fit_transform([ts["data"]])[0]
# create column names
out["row_names"] = ts["columns"]
out["col_names"] = ts["columns"]
out["kind"] = "tril"
return out

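After this change, `compute()` reduces to two steps: sphere-wise aggregation of the BOLD signal (now delegated to `SphereAggregation`) followed by nilearn's `ConnectivityMeasure`. A minimal sketch of the second step on a synthetic stand-in time series (the array shape is illustrative only, not taken from the commit):

```python
import numpy as np
from nilearn.connectome import ConnectivityMeasure
from sklearn.covariance import EmpiricalCovariance

# Stand-in for the output of SphereAggregation.compute() on BOLD data:
# a (n_timepoints, n_spheres) array of sphere-averaged signals.
ts = np.random.default_rng(0).standard_normal((100, 6))

# When cor_method_params["empirical"] is set, the marker swaps nilearn's
# default Ledoit-Wolf covariance estimator for EmpiricalCovariance.
cm = ConnectivityMeasure(cov_estimator=EmpiricalCovariance(), kind="correlation")
fc_matrix = cm.fit_transform([ts])[0]  # (6, 6) functional connectivity matrix
```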
41 changes: 39 additions & 2 deletions junifer/markers/tests/test_functional_connectivity_spheres.py
@@ -2,16 +2,22 @@

# Authors: Amir Omidvarnia <a.omidvarnia@fz-juelich.de>
# Kaustubh R. Patil <k.patil@fz-juelich.de>
# Federico Raimondo <f.raimondo@fz-juelich.de>
# License: AGPL
from pathlib import Path
from numpy.testing import assert_array_almost_equal

from nilearn import datasets, image
from nilearn.connectome import ConnectivityMeasure

from junifer.markers.functional_connectivity_spheres import (
FunctionalConnectivitySpheres,
)
from junifer.markers.sphere_aggregation import SphereAggregation
from junifer.storage import SQLiteFeatureStorage


def test_FunctionalConnectivitySpheres() -> None:
def test_FunctionalConnectivitySpheres(tmp_path: Path) -> None:
"""Test FunctionalConnectivitySpheres."""

# get a dataset
@@ -21,7 +27,9 @@ def test_FunctionalConnectivitySpheres() -> None:
fc = FunctionalConnectivitySpheres(
coords="DMNBuckner", radius=5.0, cor_method="correlation"
)
out = fc.compute({"data": fmri_img})
all_out = fc.fit_transform({"BOLD": {"data": fmri_img}})

out = all_out["BOLD"]

assert "data" in out
assert "row_names" in out
@@ -30,3 +38,32 @@ def test_FunctionalConnectivitySpheres() -> None:
assert out["data"].shape[1] == 6
assert len(set(out["row_names"])) == 6
assert len(set(out["col_names"])) == 6

# get the timeseries using sa
sa = SphereAggregation(
coords="DMNBuckner", radius=5.0, method="mean", on="BOLD"
)
ts = sa.compute({"data": fmri_img})

# compare with nilearn

# Check that FC are almost equal
cm = ConnectivityMeasure(kind="correlation")
out_ni = cm.fit_transform([ts["data"]])[0]
assert_array_almost_equal(out_ni, out["data"], decimal=3)

# check correct output
assert fc.get_output_kind(["BOLD"]) == ["matrix"]

uri = tmp_path / "test_fc_atlas.db"
# Single storage, must be the uri
storage = SQLiteFeatureStorage(
uri=uri, single_output=True, upsert="ignore"
)
meta = {
"element": "test",
"version": "0.0.1",
"marker": {"name": "fcname"},
}
input = {"BOLD": {"data": fmri_img}, "meta": meta}
all_out = fc.fit_transform(input, storage=storage)
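
For completeness, roughly what the updated test exercises, rewritten as a standalone script. The dataset fetch and the on-disk path are assumptions: the test's data setup is collapsed in the diff above and it writes to pytest's `tmp_path`.

```python
from pathlib import Path

from nilearn import datasets, image

from junifer.markers.functional_connectivity_spheres import (
    FunctionalConnectivitySpheres,
)
from junifer.storage import SQLiteFeatureStorage

# Assumption: any 4D BOLD image will do; the SPM auditory dataset is just a
# small public example, not necessarily what the collapsed test setup uses.
fmri_img = image.concat_imgs(datasets.fetch_spm_auditory().func)

fc = FunctionalConnectivitySpheres(
    coords="DMNBuckner", radius=5.0, cor_method="correlation"
)
storage = SQLiteFeatureStorage(
    uri=Path("/tmp/test_fc_spheres.db"),  # hypothetical path standing in for tmp_path
    single_output=True,
    upsert="ignore",
)
meta = {
    "element": "test",
    "version": "0.0.1",
    "marker": {"name": "fcname"},
}
all_out = fc.fit_transform(
    {"BOLD": {"data": fmri_img}, "meta": meta}, storage=storage
)
```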
