Skip to content

Commit

Permalink
Merge pull request #78 from kaurao/feat/marker_fc_sphere_issue41
Browse files Browse the repository at this point in the history
feat: add support for spheres-based functional connectivity
  • Loading branch information
synchon authored Oct 7, 2022
2 parents 91fc853 + d8e3c83 commit 2bc117a
Show file tree
Hide file tree
Showing 2 changed files with 285 additions and 0 deletions.
152 changes: 152 additions & 0 deletions junifer/markers/functional_connectivity_spheres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Provide base class for functional connectivity using spheres."""

# Authors: Amir Omidvarnia <a.omidvarnia@fz-juelich.de>
# Kaustubh R. Patil <k.patil@fz-juelich.de>
# License: AGPL

from typing import Dict, List, Optional

from nilearn.connectome import ConnectivityMeasure
from sklearn.covariance import EmpiricalCovariance

from ..api.decorators import register_marker
from ..utils import logger, raise_error
from .base import BaseMarker
from .sphere_aggregation import SphereAggregation


@register_marker
class FunctionalConnectivitySpheres(BaseMarker):
"""Class for functional connectivity using coordinates (spheres).
Parameters
----------
coords : str
The name of the coordinates list to use. See
:mod:`junifer.data.coordinates`
radius : float
The radius of the sphere in mm. If None, the signal will be extracted
from a single voxel. See :class:`nilearn.maskers.NiftiSpheresMasker`
for more information.
agg_method : str
The aggregation method to use.
See :func:`junifer.stats.get_aggfunc_by_name` for more information.
agg_method_params : dict, optional
The parameters to pass to the aggregation method.
name : str, optional
The name of the marker. By default, it will use
KIND_FunctionalConnectivitySpheres where KIND is the kind of data it
was applied to (default None).
"""

def __init__(
self,
coords: str,
radius: float,
agg_method: str = "mean",
agg_method_params: Optional[Dict] = None,
cor_method: str = "covariance",
cor_method_params: Optional[Dict] = None,
name: Optional[str] = None,
) -> None:
"""Initialize the class."""
self.coords = coords
self.radius = radius
if radius is None or radius <= 0:
raise_error(f"radius should be > 0: provided {radius}")
self.agg_method = agg_method
self.agg_method_params = (
{} if agg_method_params is None else agg_method_params
)
self.cor_method = cor_method
self.cor_method_params = (
{} if cor_method_params is None else cor_method_params
)
# default to nilearn behavior
self.cor_method_params["empirical"] = self.cor_method_params.get(
"empirical", False
)

super().__init__(on=["BOLD"], name=name)

def get_output_kind(self, input: List[str]) -> List[str]:
"""Get output kind.
Parameters
----------
input : list of str
The input to the marker. The list must contain the
available Junifer Data dictionary keys.
Returns
-------
list of str
The updated list of output kinds, as storage possibilities.
"""
outputs = ["matrix"]
return outputs

def compute(self, input: Dict, extra_input: Optional[Dict] = None) -> Dict:
"""Compute.
Parameters
----------
input : dict[str, dict]
A single input from the pipeline data object in which to compute
the marker.
extra_input : dict, optional
The other fields in the pipeline data object. Useful for accessing
other data kind that needs to be used in the computation. For
example, the functional connectivity markers can make use of the
confounds if available (default None).
Returns
-------
dict
The computed result as dictionary. The following keys will be
included in the dictionary:
- 'data': FC matrix as a 2D numpy array.
- 'row_names': Row names as a list.
- 'col_names': Col names as a list.
- 'kind': The kind of matrix (tril, triu or full)
"""
sa = SphereAggregation(
coords=self.coords,
radius=self.radius,
method=self.agg_method,
method_params=self.agg_method_params,
on="BOLD",
)

ts = sa.compute(input)

if self.cor_method_params["empirical"]:
cm = ConnectivityMeasure(
cov_estimator=EmpiricalCovariance(), # type: ignore
kind=self.cor_method,
)
else:
cm = ConnectivityMeasure(kind=self.cor_method)
out = {}
out["data"] = cm.fit_transform([ts["data"]])[0]
# create column names
out["row_names"] = ts["columns"]
out["col_names"] = ts["columns"]
out["kind"] = "tril"
return out

# TODO: complete type annotations
def store(self, kind: str, out: Dict, storage) -> None:
"""Store.
Parameters
----------
input
out
"""
logger.debug(f"Storing {kind} in {storage}")
storage.store_matrix2d(**out)
133 changes: 133 additions & 0 deletions junifer/markers/tests/test_functional_connectivity_spheres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""Provide test for functional connectivity spheres."""

# Authors: Amir Omidvarnia <a.omidvarnia@fz-juelich.de>
# Kaustubh R. Patil <k.patil@fz-juelich.de>
# Federico Raimondo <f.raimondo@fz-juelich.de>
# License: AGPL

import pytest

from pathlib import Path
from numpy.testing import assert_array_almost_equal

from sklearn.covariance import EmpiricalCovariance
from nilearn import datasets, image
from nilearn.connectome import ConnectivityMeasure

from junifer.markers.functional_connectivity_spheres import (
FunctionalConnectivitySpheres,
)
from junifer.markers.sphere_aggregation import SphereAggregation
from junifer.storage import SQLiteFeatureStorage


def test_FunctionalConnectivitySpheres(tmp_path: Path) -> None:
"""Test FunctionalConnectivitySpheres.
Parameters
----------
tmp_path : pathlib.Path
The path to the test directory.
"""

# get a dataset
ni_data = datasets.fetch_spm_auditory(subject_id="sub001")
fmri_img = image.concat_imgs(ni_data.func) # type: ignore

fc = FunctionalConnectivitySpheres(
coords="DMNBuckner", radius=5.0, cor_method="correlation"
)
all_out = fc.fit_transform({"BOLD": {"data": fmri_img}})

out = all_out["BOLD"]

assert "data" in out
assert "row_names" in out
assert "col_names" in out
assert out["data"].shape[0] == 6
assert out["data"].shape[1] == 6
assert len(set(out["row_names"])) == 6
assert len(set(out["col_names"])) == 6

# get the timeseries using sa
sa = SphereAggregation(
coords="DMNBuckner", radius=5.0, method="mean", on="BOLD"
)
ts = sa.compute({"data": fmri_img})

# Check that FC are almost equal when using nileran
cm = ConnectivityMeasure(kind="correlation")
out_ni = cm.fit_transform([ts["data"]])[0]
assert_array_almost_equal(out_ni, out["data"], decimal=3)

# check correct output
assert fc.get_output_kind(["BOLD"]) == ["matrix"]

uri = tmp_path / "test_fc_atlas.db"
# Single storage, must be the uri
storage = SQLiteFeatureStorage(
uri=uri, single_output=True, upsert="ignore"
)
meta = {
"element": "test",
"version": "0.0.1",
"marker": {"name": "fcname"},
}
input = {"BOLD": {"data": fmri_img}, "meta": meta}
all_out = fc.fit_transform(input, storage=storage)


def test_FunctionalConnectivitySpheres_empirical(tmp_path: Path) -> None:
"""Test FunctionalConnectivitySpheres with empirical covariance.
Parameters
----------
tmp_path : pathlib.Path
The path to the test directory.
"""

# get a dataset
ni_data = datasets.fetch_spm_auditory(subject_id="sub001")
fmri_img = image.concat_imgs(ni_data.func) # type: ignore

fc = FunctionalConnectivitySpheres(
coords="DMNBuckner",
radius=5.0,
cor_method="correlation",
cor_method_params={"empirical": True},
)
all_out = fc.fit_transform({"BOLD": {"data": fmri_img}})

out = all_out["BOLD"]

assert "data" in out
assert "row_names" in out
assert "col_names" in out
assert out["data"].shape[0] == 6
assert out["data"].shape[1] == 6
assert len(set(out["row_names"])) == 6
assert len(set(out["col_names"])) == 6

# get the timeseries using sa
sa = SphereAggregation(
coords="DMNBuckner", radius=5.0, method="mean", on="BOLD"
)
ts = sa.compute({"data": fmri_img})

# Check that FC are almost equal when using nileran
cm = ConnectivityMeasure(
cov_estimator=EmpiricalCovariance(), # type: ignore
kind="correlation"
)
out_ni = cm.fit_transform([ts["data"]])[0]
assert_array_almost_equal(out_ni, out["data"], decimal=3)


def test_FunctionalConnectivitySpheres_error() -> None:
"""Test FunctionalConnectivitySpheres errors."""
with pytest.raises(ValueError, match="radius should be > 0"):
FunctionalConnectivitySpheres(
coords="DMNBuckner", radius=-0.1, cor_method="correlation"
)

0 comments on commit 2bc117a

Please sign in to comment.