diff --git a/junifer/markers/functional_connectivity_spheres.py b/junifer/markers/functional_connectivity_spheres.py new file mode 100644 index 0000000000..36c401e25f --- /dev/null +++ b/junifer/markers/functional_connectivity_spheres.py @@ -0,0 +1,152 @@ +"""Provide base class for functional connectivity using spheres.""" + +# Authors: Amir Omidvarnia +# Kaustubh R. Patil +# License: AGPL + +from typing import Dict, List, Optional + +from nilearn.connectome import ConnectivityMeasure +from sklearn.covariance import EmpiricalCovariance + +from ..api.decorators import register_marker +from ..utils import logger, raise_error +from .base import BaseMarker +from .sphere_aggregation import SphereAggregation + + +@register_marker +class FunctionalConnectivitySpheres(BaseMarker): + """Class for functional connectivity using coordinates (spheres). + + Parameters + ---------- + coords : str + The name of the coordinates list to use. See + :mod:`junifer.data.coordinates` + radius : float + The radius of the sphere in mm. If None, the signal will be extracted + from a single voxel. See :class:`nilearn.maskers.NiftiSpheresMasker` + for more information. + agg_method : str + The aggregation method to use. + See :func:`junifer.stats.get_aggfunc_by_name` for more information. + agg_method_params : dict, optional + The parameters to pass to the aggregation method. + name : str, optional + The name of the marker. By default, it will use + KIND_FunctionalConnectivitySpheres where KIND is the kind of data it + was applied to (default None). + + """ + + def __init__( + self, + coords: str, + radius: float, + agg_method: str = "mean", + agg_method_params: Optional[Dict] = None, + cor_method: str = "covariance", + cor_method_params: Optional[Dict] = None, + name: Optional[str] = None, + ) -> None: + """Initialize the class.""" + self.coords = coords + self.radius = radius + if radius is None or radius <= 0: + raise_error(f"radius should be > 0: provided {radius}") + self.agg_method = agg_method + self.agg_method_params = ( + {} if agg_method_params is None else agg_method_params + ) + self.cor_method = cor_method + self.cor_method_params = ( + {} if cor_method_params is None else cor_method_params + ) + # default to nilearn behavior + self.cor_method_params["empirical"] = self.cor_method_params.get( + "empirical", False + ) + + super().__init__(on=["BOLD"], name=name) + + def get_output_kind(self, input: List[str]) -> List[str]: + """Get output kind. + + Parameters + ---------- + input : list of str + The input to the marker. The list must contain the + available Junifer Data dictionary keys. + + Returns + ------- + list of str + The updated list of output kinds, as storage possibilities. + + """ + outputs = ["matrix"] + return outputs + + def compute(self, input: Dict, extra_input: Optional[Dict] = None) -> Dict: + """Compute. + + Parameters + ---------- + input : dict[str, dict] + A single input from the pipeline data object in which to compute + the marker. + extra_input : dict, optional + The other fields in the pipeline data object. Useful for accessing + other data kind that needs to be used in the computation. For + example, the functional connectivity markers can make use of the + confounds if available (default None). + + Returns + ------- + dict + The computed result as dictionary. The following keys will be + included in the dictionary: + - 'data': FC matrix as a 2D numpy array. + - 'row_names': Row names as a list. + - 'col_names': Col names as a list. + - 'kind': The kind of matrix (tril, triu or full) + + """ + sa = SphereAggregation( + coords=self.coords, + radius=self.radius, + method=self.agg_method, + method_params=self.agg_method_params, + on="BOLD", + ) + + ts = sa.compute(input) + + if self.cor_method_params["empirical"]: + cm = ConnectivityMeasure( + cov_estimator=EmpiricalCovariance(), # type: ignore + kind=self.cor_method, + ) + else: + cm = ConnectivityMeasure(kind=self.cor_method) + out = {} + out["data"] = cm.fit_transform([ts["data"]])[0] + # create column names + out["row_names"] = ts["columns"] + out["col_names"] = ts["columns"] + out["kind"] = "tril" + return out + + # TODO: complete type annotations + def store(self, kind: str, out: Dict, storage) -> None: + """Store. + + Parameters + ---------- + input + out + + """ + logger.debug(f"Storing {kind} in {storage}") + storage.store_matrix2d(**out) diff --git a/junifer/markers/tests/test_functional_connectivity_spheres.py b/junifer/markers/tests/test_functional_connectivity_spheres.py new file mode 100644 index 0000000000..a58b53927c --- /dev/null +++ b/junifer/markers/tests/test_functional_connectivity_spheres.py @@ -0,0 +1,133 @@ +"""Provide test for functional connectivity spheres.""" + +# Authors: Amir Omidvarnia +# Kaustubh R. Patil +# Federico Raimondo +# License: AGPL + +import pytest + +from pathlib import Path +from numpy.testing import assert_array_almost_equal + +from sklearn.covariance import EmpiricalCovariance +from nilearn import datasets, image +from nilearn.connectome import ConnectivityMeasure + +from junifer.markers.functional_connectivity_spheres import ( + FunctionalConnectivitySpheres, +) +from junifer.markers.sphere_aggregation import SphereAggregation +from junifer.storage import SQLiteFeatureStorage + + +def test_FunctionalConnectivitySpheres(tmp_path: Path) -> None: + """Test FunctionalConnectivitySpheres. + + Parameters + ---------- + tmp_path : pathlib.Path + The path to the test directory. + + """ + + # get a dataset + ni_data = datasets.fetch_spm_auditory(subject_id="sub001") + fmri_img = image.concat_imgs(ni_data.func) # type: ignore + + fc = FunctionalConnectivitySpheres( + coords="DMNBuckner", radius=5.0, cor_method="correlation" + ) + all_out = fc.fit_transform({"BOLD": {"data": fmri_img}}) + + out = all_out["BOLD"] + + assert "data" in out + assert "row_names" in out + assert "col_names" in out + assert out["data"].shape[0] == 6 + assert out["data"].shape[1] == 6 + assert len(set(out["row_names"])) == 6 + assert len(set(out["col_names"])) == 6 + + # get the timeseries using sa + sa = SphereAggregation( + coords="DMNBuckner", radius=5.0, method="mean", on="BOLD" + ) + ts = sa.compute({"data": fmri_img}) + + # Check that FC are almost equal when using nileran + cm = ConnectivityMeasure(kind="correlation") + out_ni = cm.fit_transform([ts["data"]])[0] + assert_array_almost_equal(out_ni, out["data"], decimal=3) + + # check correct output + assert fc.get_output_kind(["BOLD"]) == ["matrix"] + + uri = tmp_path / "test_fc_atlas.db" + # Single storage, must be the uri + storage = SQLiteFeatureStorage( + uri=uri, single_output=True, upsert="ignore" + ) + meta = { + "element": "test", + "version": "0.0.1", + "marker": {"name": "fcname"}, + } + input = {"BOLD": {"data": fmri_img}, "meta": meta} + all_out = fc.fit_transform(input, storage=storage) + + +def test_FunctionalConnectivitySpheres_empirical(tmp_path: Path) -> None: + """Test FunctionalConnectivitySpheres with empirical covariance. + + Parameters + ---------- + tmp_path : pathlib.Path + The path to the test directory. + + """ + + # get a dataset + ni_data = datasets.fetch_spm_auditory(subject_id="sub001") + fmri_img = image.concat_imgs(ni_data.func) # type: ignore + + fc = FunctionalConnectivitySpheres( + coords="DMNBuckner", + radius=5.0, + cor_method="correlation", + cor_method_params={"empirical": True}, + ) + all_out = fc.fit_transform({"BOLD": {"data": fmri_img}}) + + out = all_out["BOLD"] + + assert "data" in out + assert "row_names" in out + assert "col_names" in out + assert out["data"].shape[0] == 6 + assert out["data"].shape[1] == 6 + assert len(set(out["row_names"])) == 6 + assert len(set(out["col_names"])) == 6 + + # get the timeseries using sa + sa = SphereAggregation( + coords="DMNBuckner", radius=5.0, method="mean", on="BOLD" + ) + ts = sa.compute({"data": fmri_img}) + + # Check that FC are almost equal when using nileran + cm = ConnectivityMeasure( + cov_estimator=EmpiricalCovariance(), # type: ignore + kind="correlation" + ) + out_ni = cm.fit_transform([ts["data"]])[0] + assert_array_almost_equal(out_ni, out["data"], decimal=3) + + +def test_FunctionalConnectivitySpheres_error() -> None: + """Test FunctionalConnectivitySpheres errors.""" + with pytest.raises(ValueError, match="radius should be > 0"): + FunctionalConnectivitySpheres( + coords="DMNBuckner", radius=-0.1, cor_method="correlation" + )