Skip to content

Commit

Permalink
Merge pull request #265 from juaml/feat/get-coordinates
Browse files Browse the repository at this point in the history
[ENH]: Introduce `junifer.data.get_coordinates()`
  • Loading branch information
synchon authored Oct 18, 2023
2 parents ef8da66 + a1b8951 commit 9689b02
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/265.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Introduce :func:`.get_coordinates` to fetch coordinates tailored for the data by `Synchon Mandal`_
1 change: 1 addition & 0 deletions junifer/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
list_coordinates,
load_coordinates,
register_coordinates,
get_coordinates,
)
from .parcellations import (
list_parcellations,
Expand Down
79 changes: 72 additions & 7 deletions junifer/data/coordinates.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Provide functions for list of coordinates."""

# Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
# Synchon Mandal <s.mandal@fz-juelich.de>
# License: AGPL

import typing
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -57,7 +58,7 @@ def register_coordinates(
voi_names: List[str],
overwrite: Optional[bool] = False,
) -> None:
"""Register coordinates.
"""Register a custom user coordinates.
Parameters
----------
Expand All @@ -73,6 +74,18 @@ def register_coordinates(
overwrite : bool, optional
If True, overwrite an existing list of coordinates with the same name.
Does not apply to built-in coordinates (default False).
Raises
------
ValueError
If the coordinates name is already registered and overwrite is set to
False or if the coordinates name is a built-in coordinates or if the
``coordinates`` is not a 2D array or if coordinate value does not have
3 components or if the ``voi_names`` shape does not match the
``coordinates`` shape.
TypeError
If ``coordinates`` is not a ``numpy.ndarray``.
"""
if name in _available_coordinates:
if isinstance(_available_coordinates[name], Path):
Expand All @@ -90,7 +103,8 @@ def register_coordinates(

if not isinstance(coordinates, np.ndarray):
raise_error(
f"Coordinates must be a numpy.ndarray, not {type(coordinates)}."
f"Coordinates must be a `numpy.ndarray`, not {type(coordinates)}.",
klass=TypeError,
)
if coordinates.ndim != 2:
raise_error(
Expand All @@ -102,8 +116,8 @@ def register_coordinates(
)
if len(voi_names) != coordinates.shape[0]:
raise_error(
f"Length of voi_names ({len(voi_names)}) does not match the "
f"number of coordinates ({coordinates.shape[0]})."
f"Length of `voi_names` ({len(voi_names)}) does not match the "
f"number of `coordinates` ({coordinates.shape[0]})."
)
_available_coordinates[name] = {
"coords": coordinates,
Expand All @@ -112,16 +126,55 @@ def register_coordinates(


def list_coordinates() -> List[str]:
"""List all the available coordinates lists (VOIs).
"""List all the available coordinates (VOIs).
Returns
-------
list of str
A list with all available coordinates names.
"""
return sorted(_available_coordinates.keys())


def get_coordinates(
    coords: str,
    target_data: Dict[str, Any],
    extra_input: Optional[Dict[str, Any]] = None,
) -> Tuple[ArrayLike, List[str]]:
    """Get coordinates, tailored for the target image.

    At present no tailoring is performed: the coordinates are loaded by
    name via :func:`load_coordinates` and returned unchanged.
    ``target_data`` and ``extra_input`` are accepted so that callers can
    already pass them, but they are not read by this function yet.

    Parameters
    ----------
    coords : str
        The name of the coordinates.
    target_data : dict
        The corresponding item of the data object to which the coordinates
        will be applied (currently unused).
    extra_input : dict, optional
        The other fields in the data object. Useful for accessing other data
        kinds that need to be used in the computation of coordinates
        (currently unused; default None).

    Returns
    -------
    numpy.ndarray
        The coordinates.
    list of str
        The names of the VOIs.

    Raises
    ------
    ValueError
        If ``coords`` is not a registered coordinates name (raised by
        :func:`load_coordinates`).

    """
    # Load the coordinates by name; they are returned as-is because no
    # target-specific transformation is applied here.
    seeds, labels = load_coordinates(name=coords)

    return seeds, labels


def load_coordinates(name: str) -> Tuple[ArrayLike, List[str]]:
"""Load coordinates.
Expand All @@ -137,14 +190,23 @@ def load_coordinates(name: str) -> Tuple[ArrayLike, List[str]]:
list of str
The names of the VOIs.
Raises
------
ValueError
If ``name`` is invalid.
Warns
-----
DeprecationWarning
If ``Power`` is provided as the ``name``.
"""
# Check for valid coordinates name
if name not in _available_coordinates:
raise_error(f"Coordinates {name} not found.")
raise_error(
f"Coordinates {name} not found. "
f"Valid options are: {list_coordinates()}"
)

# Put up deprecation notice
if name == "Power":
Expand All @@ -157,8 +219,10 @@ def load_coordinates(name: str) -> Tuple[ArrayLike, List[str]]:
category=DeprecationWarning,
)

# Load coordinates
t_coord = _available_coordinates[name]
if isinstance(t_coord, Path):
# Load via pandas
df_coords = pd.read_csv(t_coord, sep="\t", header=None)
coords = df_coords.iloc[:, [0, 1, 2]].to_numpy()
names = list(df_coords.iloc[:, [3]].values[:, 0])
Expand All @@ -167,4 +231,5 @@ def load_coordinates(name: str) -> Tuple[ArrayLike, List[str]]:
coords = typing.cast(ArrayLike, coords)
names = t_coord["voi_names"]
names = typing.cast(List[str], names)

return coords, names
23 changes: 22 additions & 1 deletion junifer/data/tests/test_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
from numpy.testing import assert_array_equal

from junifer.data.coordinates import (
get_coordinates,
list_coordinates,
load_coordinates,
register_coordinates,
)
from junifer.datareader import DefaultDataReader
from junifer.testing.datagrabbers import OasisVBMTestingDataGrabber


def test_register_coordinates_built_in_check() -> None:
Expand Down Expand Up @@ -54,7 +57,7 @@ def test_register_coordinates_overwrite() -> None:

def test_register_coordinates_valid_input() -> None:
"""Test coordinates registration check for valid input."""
with pytest.raises(ValueError, match=r"numpy.ndarray"):
with pytest.raises(TypeError, match=r"numpy.ndarray"):
register_coordinates(
name="MyList",
coordinates=[1, 2],
Expand Down Expand Up @@ -105,3 +108,21 @@ def test_load_coordinates_nonexisting() -> None:
"""Test loading coordinates that not exist."""
with pytest.raises(ValueError, match=r"not found"):
load_coordinates("NonExisting")


def test_get_coordinates() -> None:
    """Check that tailored coordinates currently equal the raw ones."""
    data_reader = DefaultDataReader()
    with OasisVBMTestingDataGrabber() as datagrabber:
        # Read the testing element and pick the VBM_GM data item
        loaded = data_reader.fit_transform(datagrabber["sub-01"])
        target = loaded["VBM_GM"]
        # Coordinates fetched through the tailoring entry point
        got_coords, got_labels = get_coordinates(
            coords="DMNBuckner", target_data=target
        )
        # Coordinates loaded directly, without any tailoring
        expected_coords, expected_labels = load_coordinates("DMNBuckner")
        # No tailoring is applied yet, so both must agree
        assert_array_equal(got_coords, expected_coords)
        assert got_labels == expected_labels
38 changes: 33 additions & 5 deletions junifer/markers/sphere_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, ClassVar, Dict, List, Optional, Set, Union

from ..api.decorators import register_marker
from ..data import get_mask, load_coordinates
from ..data import get_coordinates, get_mask
from ..external.nilearn import JuniferNiftiSpheresMasker
from ..stats import get_aggfunc_by_name
from ..utils import logger, raise_error, warn_with_log
Expand Down Expand Up @@ -55,6 +55,12 @@ class SphereAggregation(BaseMarker):
The name of the marker. By default, it will use KIND_SphereAggregation
where KIND is the kind of data it was applied to (default None).
Raises
------
ValueError
If ``time_method`` is specified for non-BOLD data or if
``time_method_params`` is not None when ``time_method`` is None.
"""

_DEPENDENCIES: ClassVar[Set[str]] = {"nilearn", "numpy"}
Expand Down Expand Up @@ -118,14 +124,19 @@ def get_output_type(self, input_type: str) -> str:
str
The storage type output by the marker.
Raises
------
ValueError
If the ``input_type`` is invalid.
"""

if input_type in ["VBM_GM", "VBM_WM", "fALFF", "GCOR", "LCOR"]:
return "vector"
elif input_type == "BOLD":
return "timeseries"
else:
raise ValueError(f"Unknown input kind for {input_type}")
raise_error(f"Unknown input kind for {input_type}")

def compute(
self,
Expand Down Expand Up @@ -155,22 +166,37 @@ def compute(
* ``data`` : the actual computed values as a numpy.ndarray
* ``col_names`` : the column labels for the computed values as list
Warns
-----
RuntimeWarning
If time aggregation is required but only one time point is
available.
"""
t_input_img = input["data"]
logger.debug(f"Sphere aggregation using {self.method}")
# Get aggregation function
agg_func = get_aggfunc_by_name(
self.method, func_params=self.method_params
)

# Get seeds and labels tailored to target image
coords, labels = get_coordinates(
coords=self.coords,
target_data=input,
extra_input=extra_input,
)

# Load mask
mask_img = None
if self.masks is not None:
logger.debug(f"Masking with {self.masks}")
# Get tailored mask
mask_img = get_mask(
masks=self.masks, target_data=input, extra_input=extra_input
)
# Get seeds and labels
coords, out_labels = load_coordinates(name=self.coords)

# Initialize masker
logger.debug("Masking")
masker = JuniferNiftiSpheresMasker(
seeds=coords,
radius=self.radius,
Expand All @@ -180,6 +206,8 @@ def compute(
)
# Fit and transform the marker on the data
out_values = masker.fit_transform(t_input_img)

# Apply time dimension aggregation if required
if self.time_method is not None:
if out_values.shape[0] > 1:
logger.debug("Aggregating time dimension")
Expand All @@ -193,5 +221,5 @@ def compute(
"available."
)
# Format the output
out = {"data": out_values, "col_names": out_labels}
out = {"data": out_values, "col_names": labels}
return out

0 comments on commit 9689b02

Please sign in to comment.