-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #78 from kaurao/feat/marker_fc_sphere_issue41
feat: add support for spheres-based functional connectivity
- Loading branch information
Showing
2 changed files
with
285 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
"""Provide base class for functional connectivity using spheres.""" | ||
|
||
# Authors: Amir Omidvarnia <a.omidvarnia@fz-juelich.de> | ||
# Kaustubh R. Patil <k.patil@fz-juelich.de> | ||
# License: AGPL | ||
|
||
from typing import Dict, List, Optional | ||
|
||
from nilearn.connectome import ConnectivityMeasure | ||
from sklearn.covariance import EmpiricalCovariance | ||
|
||
from ..api.decorators import register_marker | ||
from ..utils import logger, raise_error | ||
from .base import BaseMarker | ||
from .sphere_aggregation import SphereAggregation | ||
|
||
|
||
@register_marker | ||
class FunctionalConnectivitySpheres(BaseMarker): | ||
"""Class for functional connectivity using coordinates (spheres). | ||
Parameters | ||
---------- | ||
coords : str | ||
The name of the coordinates list to use. See | ||
:mod:`junifer.data.coordinates` | ||
radius : float | ||
The radius of the sphere in mm. If None, the signal will be extracted | ||
from a single voxel. See :class:`nilearn.maskers.NiftiSpheresMasker` | ||
for more information. | ||
agg_method : str | ||
The aggregation method to use. | ||
See :func:`junifer.stats.get_aggfunc_by_name` for more information. | ||
agg_method_params : dict, optional | ||
The parameters to pass to the aggregation method. | ||
name : str, optional | ||
The name of the marker. By default, it will use | ||
KIND_FunctionalConnectivitySpheres where KIND is the kind of data it | ||
was applied to (default None). | ||
""" | ||
|
||
def __init__( | ||
self, | ||
coords: str, | ||
radius: float, | ||
agg_method: str = "mean", | ||
agg_method_params: Optional[Dict] = None, | ||
cor_method: str = "covariance", | ||
cor_method_params: Optional[Dict] = None, | ||
name: Optional[str] = None, | ||
) -> None: | ||
"""Initialize the class.""" | ||
self.coords = coords | ||
self.radius = radius | ||
if radius is None or radius <= 0: | ||
raise_error(f"radius should be > 0: provided {radius}") | ||
self.agg_method = agg_method | ||
self.agg_method_params = ( | ||
{} if agg_method_params is None else agg_method_params | ||
) | ||
self.cor_method = cor_method | ||
self.cor_method_params = ( | ||
{} if cor_method_params is None else cor_method_params | ||
) | ||
# default to nilearn behavior | ||
self.cor_method_params["empirical"] = self.cor_method_params.get( | ||
"empirical", False | ||
) | ||
|
||
super().__init__(on=["BOLD"], name=name) | ||
|
||
def get_output_kind(self, input: List[str]) -> List[str]: | ||
"""Get output kind. | ||
Parameters | ||
---------- | ||
input : list of str | ||
The input to the marker. The list must contain the | ||
available Junifer Data dictionary keys. | ||
Returns | ||
------- | ||
list of str | ||
The updated list of output kinds, as storage possibilities. | ||
""" | ||
outputs = ["matrix"] | ||
return outputs | ||
|
||
def compute(self, input: Dict, extra_input: Optional[Dict] = None) -> Dict: | ||
"""Compute. | ||
Parameters | ||
---------- | ||
input : dict[str, dict] | ||
A single input from the pipeline data object in which to compute | ||
the marker. | ||
extra_input : dict, optional | ||
The other fields in the pipeline data object. Useful for accessing | ||
other data kind that needs to be used in the computation. For | ||
example, the functional connectivity markers can make use of the | ||
confounds if available (default None). | ||
Returns | ||
------- | ||
dict | ||
The computed result as dictionary. The following keys will be | ||
included in the dictionary: | ||
- 'data': FC matrix as a 2D numpy array. | ||
- 'row_names': Row names as a list. | ||
- 'col_names': Col names as a list. | ||
- 'kind': The kind of matrix (tril, triu or full) | ||
""" | ||
sa = SphereAggregation( | ||
coords=self.coords, | ||
radius=self.radius, | ||
method=self.agg_method, | ||
method_params=self.agg_method_params, | ||
on="BOLD", | ||
) | ||
|
||
ts = sa.compute(input) | ||
|
||
if self.cor_method_params["empirical"]: | ||
cm = ConnectivityMeasure( | ||
cov_estimator=EmpiricalCovariance(), # type: ignore | ||
kind=self.cor_method, | ||
) | ||
else: | ||
cm = ConnectivityMeasure(kind=self.cor_method) | ||
out = {} | ||
out["data"] = cm.fit_transform([ts["data"]])[0] | ||
# create column names | ||
out["row_names"] = ts["columns"] | ||
out["col_names"] = ts["columns"] | ||
out["kind"] = "tril" | ||
return out | ||
|
||
# TODO: complete type annotations | ||
def store(self, kind: str, out: Dict, storage) -> None: | ||
"""Store. | ||
Parameters | ||
---------- | ||
input | ||
out | ||
""" | ||
logger.debug(f"Storing {kind} in {storage}") | ||
storage.store_matrix2d(**out) |
133 changes: 133 additions & 0 deletions
133
junifer/markers/tests/test_functional_connectivity_spheres.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
"""Provide test for functional connectivity spheres.""" | ||
|
||
# Authors: Amir Omidvarnia <a.omidvarnia@fz-juelich.de> | ||
# Kaustubh R. Patil <k.patil@fz-juelich.de> | ||
# Federico Raimondo <f.raimondo@fz-juelich.de> | ||
# License: AGPL | ||
|
||
import pytest | ||
|
||
from pathlib import Path | ||
from numpy.testing import assert_array_almost_equal | ||
|
||
from sklearn.covariance import EmpiricalCovariance | ||
from nilearn import datasets, image | ||
from nilearn.connectome import ConnectivityMeasure | ||
|
||
from junifer.markers.functional_connectivity_spheres import ( | ||
FunctionalConnectivitySpheres, | ||
) | ||
from junifer.markers.sphere_aggregation import SphereAggregation | ||
from junifer.storage import SQLiteFeatureStorage | ||
|
||
|
||
def test_FunctionalConnectivitySpheres(tmp_path: Path) -> None: | ||
"""Test FunctionalConnectivitySpheres. | ||
Parameters | ||
---------- | ||
tmp_path : pathlib.Path | ||
The path to the test directory. | ||
""" | ||
|
||
# get a dataset | ||
ni_data = datasets.fetch_spm_auditory(subject_id="sub001") | ||
fmri_img = image.concat_imgs(ni_data.func) # type: ignore | ||
|
||
fc = FunctionalConnectivitySpheres( | ||
coords="DMNBuckner", radius=5.0, cor_method="correlation" | ||
) | ||
all_out = fc.fit_transform({"BOLD": {"data": fmri_img}}) | ||
|
||
out = all_out["BOLD"] | ||
|
||
assert "data" in out | ||
assert "row_names" in out | ||
assert "col_names" in out | ||
assert out["data"].shape[0] == 6 | ||
assert out["data"].shape[1] == 6 | ||
assert len(set(out["row_names"])) == 6 | ||
assert len(set(out["col_names"])) == 6 | ||
|
||
# get the timeseries using sa | ||
sa = SphereAggregation( | ||
coords="DMNBuckner", radius=5.0, method="mean", on="BOLD" | ||
) | ||
ts = sa.compute({"data": fmri_img}) | ||
|
||
# Check that FC are almost equal when using nileran | ||
cm = ConnectivityMeasure(kind="correlation") | ||
out_ni = cm.fit_transform([ts["data"]])[0] | ||
assert_array_almost_equal(out_ni, out["data"], decimal=3) | ||
|
||
# check correct output | ||
assert fc.get_output_kind(["BOLD"]) == ["matrix"] | ||
|
||
uri = tmp_path / "test_fc_atlas.db" | ||
# Single storage, must be the uri | ||
storage = SQLiteFeatureStorage( | ||
uri=uri, single_output=True, upsert="ignore" | ||
) | ||
meta = { | ||
"element": "test", | ||
"version": "0.0.1", | ||
"marker": {"name": "fcname"}, | ||
} | ||
input = {"BOLD": {"data": fmri_img}, "meta": meta} | ||
all_out = fc.fit_transform(input, storage=storage) | ||
|
||
|
||
def test_FunctionalConnectivitySpheres_empirical(tmp_path: Path) -> None: | ||
"""Test FunctionalConnectivitySpheres with empirical covariance. | ||
Parameters | ||
---------- | ||
tmp_path : pathlib.Path | ||
The path to the test directory. | ||
""" | ||
|
||
# get a dataset | ||
ni_data = datasets.fetch_spm_auditory(subject_id="sub001") | ||
fmri_img = image.concat_imgs(ni_data.func) # type: ignore | ||
|
||
fc = FunctionalConnectivitySpheres( | ||
coords="DMNBuckner", | ||
radius=5.0, | ||
cor_method="correlation", | ||
cor_method_params={"empirical": True}, | ||
) | ||
all_out = fc.fit_transform({"BOLD": {"data": fmri_img}}) | ||
|
||
out = all_out["BOLD"] | ||
|
||
assert "data" in out | ||
assert "row_names" in out | ||
assert "col_names" in out | ||
assert out["data"].shape[0] == 6 | ||
assert out["data"].shape[1] == 6 | ||
assert len(set(out["row_names"])) == 6 | ||
assert len(set(out["col_names"])) == 6 | ||
|
||
# get the timeseries using sa | ||
sa = SphereAggregation( | ||
coords="DMNBuckner", radius=5.0, method="mean", on="BOLD" | ||
) | ||
ts = sa.compute({"data": fmri_img}) | ||
|
||
# Check that FC are almost equal when using nileran | ||
cm = ConnectivityMeasure( | ||
cov_estimator=EmpiricalCovariance(), # type: ignore | ||
kind="correlation" | ||
) | ||
out_ni = cm.fit_transform([ts["data"]])[0] | ||
assert_array_almost_equal(out_ni, out["data"], decimal=3) | ||
|
||
|
||
def test_FunctionalConnectivitySpheres_error() -> None: | ||
"""Test FunctionalConnectivitySpheres errors.""" | ||
with pytest.raises(ValueError, match="radius should be > 0"): | ||
FunctionalConnectivitySpheres( | ||
coords="DMNBuckner", radius=-0.1, cor_method="correlation" | ||
) |