From c9752210c6ed81089117a0c151375bae6e696090 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 27 Sep 2023 10:37:43 +0200 Subject: [PATCH 01/13] update: introduce data.coordinates.get_coordinates() to fetch tailored fit coordinates --- junifer/data/__init__.py | 1 + junifer/data/coordinates.py | 40 ++++++++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/junifer/data/__init__.py b/junifer/data/__init__.py index 65c6e90107..e879179bc3 100644 --- a/junifer/data/__init__.py +++ b/junifer/data/__init__.py @@ -8,6 +8,7 @@ list_coordinates, load_coordinates, register_coordinates, + get_coordinates, ) from .parcellations import ( list_parcellations, diff --git a/junifer/data/coordinates.py b/junifer/data/coordinates.py index 8fe242392f..786252195e 100644 --- a/junifer/data/coordinates.py +++ b/junifer/data/coordinates.py @@ -5,7 +5,7 @@ import typing from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -122,6 +122,44 @@ def list_coordinates() -> List[str]: return sorted(_available_coordinates.keys()) +def get_coordinates( + coords: str, + target_data: Dict[str, Any], + extra_input: Optional[Dict[str, Any]] = None, +) -> Tuple[ArrayLike, List[str]]: + """Get coordinates, tailored for the target image. + + Parameters + ---------- + coords : str + The name of the coordinates. + target_data : dict + The corresponding item of the data object to which the coordinates + will be applied. + extra_input : dict, optional + The other fields in the data object. Useful for accessing other data + kinds that need to be used in the computation of coordinates + (default None). + + Returns + ------- + numpy.ndarray + The coordinates. + list of str + The names of the VOIs. + + Raises + ------ + ValueError + If ``extra_input`` is None when ``target_data``'s space is not MNI. 
+ + """ + # Load the coordinates + seeds, labels, _ = load_coordinates(name=coords) + + return seeds, labels + + def load_coordinates(name: str) -> Tuple[ArrayLike, List[str]]: """Load coordinates. From cb0d9b2f79f2a16924cc14ccabe5593af7edc770 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 27 Sep 2023 10:43:50 +0200 Subject: [PATCH 02/13] update: replace load_coordinates() with get_coordinates() in SphereAggregation --- junifer/markers/sphere_aggregation.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/junifer/markers/sphere_aggregation.py b/junifer/markers/sphere_aggregation.py index a5d9d5dd97..835ca3b671 100644 --- a/junifer/markers/sphere_aggregation.py +++ b/junifer/markers/sphere_aggregation.py @@ -7,7 +7,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Set, Union from ..api.decorators import register_marker -from ..data import get_mask, load_coordinates +from ..data import get_coordinates, get_mask from ..external.nilearn import JuniferNiftiSpheresMasker from ..stats import get_aggfunc_by_name from ..utils import logger, raise_error, warn_with_log @@ -162,6 +162,14 @@ def compute( agg_func = get_aggfunc_by_name( self.method, func_params=self.method_params ) + + # Get seeds and labels tailored to target image + coords, labels = get_coordinates( + coords=self.coords, + target_data=input, + extra_input=extra_input, + ) + # Load mask mask_img = None if self.masks is not None: @@ -169,8 +177,6 @@ def compute( mask_img = get_mask( masks=self.masks, target_data=input, extra_input=extra_input ) - # Get seeds and labels - coords, out_labels = load_coordinates(name=self.coords) masker = JuniferNiftiSpheresMasker( seeds=coords, radius=self.radius, @@ -193,5 +199,5 @@ def compute( "available." 
) # Format the output - out = {"data": out_values, "col_names": out_labels} + out = {"data": out_values, "col_names": labels} return out From 8f749d6448841ea64dc4bcb5b68dbb33d9c1dbeb Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 27 Sep 2023 10:38:56 +0200 Subject: [PATCH 03/13] update: add tests for data.coordinates.get_coordinates() --- junifer/data/tests/test_coordinates.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/junifer/data/tests/test_coordinates.py b/junifer/data/tests/test_coordinates.py index 630018b058..c20f7f4437 100644 --- a/junifer/data/tests/test_coordinates.py +++ b/junifer/data/tests/test_coordinates.py @@ -8,10 +8,13 @@ from numpy.testing import assert_array_equal from junifer.data.coordinates import ( + get_coordinates, list_coordinates, load_coordinates, register_coordinates, ) +from junifer.datareader import DefaultDataReader +from junifer.testing.datagrabbers import OasisVBMTestingDataGrabber def test_register_coordinates_built_in_check() -> None: @@ -105,3 +108,21 @@ def test_load_coordinates_nonexisting() -> None: """Test loading coordinates that not exist.""" with pytest.raises(ValueError, match=r"not found"): load_coordinates("NonExisting") + + +def test_get_coordinates() -> None: + """Test tailored coordinates fetch.""" + reader = DefaultDataReader() + with OasisVBMTestingDataGrabber() as dg: + element = dg["sub-01"] + element_data = reader.fit_transform(element) + vbm_gm = element_data["VBM_GM"] + # Get tailored coordinates + tailored_coords, tailored_labels = get_coordinates( + coords="DMNBuckner", target_data=vbm_gm + ) + # Get raw coordinates + raw_coords, raw_labels, _ = load_coordinates("DMNBuckner") + # Both tailored and raw should be same for now + assert_array_equal(tailored_coords, raw_coords) + assert tailored_labels == raw_labels From 3a25a02a93a9398d436afdab04281a542d5c4235 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 20 Sep 2023 15:11:26 +0200 Subject: [PATCH 04/13] chore: 
improve docstrings in data.coordinates --- junifer/data/coordinates.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/junifer/data/coordinates.py b/junifer/data/coordinates.py index 786252195e..d87a73ada0 100644 --- a/junifer/data/coordinates.py +++ b/junifer/data/coordinates.py @@ -57,7 +57,7 @@ def register_coordinates( voi_names: List[str], overwrite: Optional[bool] = False, ) -> None: - """Register coordinates. + """Register custom user coordinates. Parameters ---------- @@ -73,6 +73,18 @@ def register_coordinates( overwrite : bool, optional If True, overwrite an existing list of coordinates with the same name. Does not apply to built-in coordinates (default False). + + Raises + ------ + ValueError + If the coordinates name is already registered and overwrite is set to + False or if the coordinates name is built-in or if the + ``coordinates`` is not a 2D array or if a coordinate value does not have + 3 components or if the ``voi_names`` shape does not match the + ``coordinates`` shape. + TypeError + If ``coordinates`` is not a ``numpy.ndarray``. + """ if name in _available_coordinates: if isinstance(_available_coordinates[name], Path): @@ -112,12 +124,13 @@ def register_coordinates( def list_coordinates() -> List[str]: - """List all the available coordinates lists (VOIs). + """List all the available coordinates (VOIs). Returns ------- list of str A list with all available coordinates names. + """ return sorted(_available_coordinates.keys()) @@ -175,6 +188,11 @@ def load_coordinates(name: str) -> Tuple[ArrayLike, List[str]]: list of str The names of the VOIs. + Raises + ------ + ValueError + If ``name`` is invalid. 
+ Warns ----- DeprecationWarning From 0102d53d8633bfeeccf58da8dace3d3f94dc172c Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 20 Sep 2023 15:12:00 +0200 Subject: [PATCH 05/13] chore: improve logging messages in data.coordinates --- junifer/data/coordinates.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/junifer/data/coordinates.py b/junifer/data/coordinates.py index d87a73ada0..e881d4174f 100644 --- a/junifer/data/coordinates.py +++ b/junifer/data/coordinates.py @@ -1,6 +1,7 @@ """Provide functions for list of coordinates.""" # Authors: Federico Raimondo +# Synchon Mandal # License: AGPL import typing @@ -102,7 +103,8 @@ def register_coordinates( if not isinstance(coordinates, np.ndarray): raise_error( - f"Coordinates must be a numpy.ndarray, not {type(coordinates)}." + f"Coordinates must be a `numpy.ndarray`, not {type(coordinates)}.", + klass=TypeError, ) if coordinates.ndim != 2: raise_error( @@ -114,8 +116,8 @@ def register_coordinates( ) if len(voi_names) != coordinates.shape[0]: raise_error( - f"Length of voi_names ({len(voi_names)}) does not match the " - f"number of coordinates ({coordinates.shape[0]})." + f"Length of `voi_names` ({len(voi_names)}) does not match the " + f"number of `coordinates` ({coordinates.shape[0]})." ) _available_coordinates[name] = { "coords": coordinates, @@ -200,7 +202,10 @@ def load_coordinates(name: str) -> Tuple[ArrayLike, List[str]]: """ if name not in _available_coordinates: - raise_error(f"Coordinates {name} not found.") + raise_error( + f"Coordinates {name} not found. 
" + f"Valid options are: {list_coordinates()}" + ) # Put up deprecation notice if name == "Power": From 5b19c860b86e38bac78c0ba593f349e0788c3938 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 20 Sep 2023 15:12:23 +0200 Subject: [PATCH 06/13] chore: improve commentary in data.coordinates --- junifer/data/coordinates.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/junifer/data/coordinates.py b/junifer/data/coordinates.py index e881d4174f..078cd7ae65 100644 --- a/junifer/data/coordinates.py +++ b/junifer/data/coordinates.py @@ -201,6 +201,7 @@ def load_coordinates(name: str) -> Tuple[ArrayLike, List[str]]: If ``Power`` is provided as the ``name``. """ + # Check for valid coordinates name if name not in _available_coordinates: raise_error( f"Coordinates {name} not found. " @@ -218,8 +219,10 @@ def load_coordinates(name: str) -> Tuple[ArrayLike, List[str]]: category=DeprecationWarning, ) + # Load coordinates t_coord = _available_coordinates[name] if isinstance(t_coord, Path): + # Load via pandas df_coords = pd.read_csv(t_coord, sep="\t", header=None) coords = df_coords.iloc[:, [0, 1, 2]].to_numpy() names = list(df_coords.iloc[:, [3]].values[:, 0]) @@ -228,4 +231,5 @@ def load_coordinates(name: str) -> Tuple[ArrayLike, List[str]]: coords = typing.cast(ArrayLike, coords) names = t_coord["voi_names"] names = typing.cast(List[str], names) + return coords, names From 979186231e3e96375ed07ac7cee6daabdb4a4832 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 20 Sep 2023 15:13:23 +0200 Subject: [PATCH 07/13] chore: improve docstrings for SphereAggregation --- junifer/markers/sphere_aggregation.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/junifer/markers/sphere_aggregation.py b/junifer/markers/sphere_aggregation.py index 835ca3b671..ad90ecc773 100644 --- a/junifer/markers/sphere_aggregation.py +++ b/junifer/markers/sphere_aggregation.py @@ -55,6 +55,12 @@ class SphereAggregation(BaseMarker): The name of the marker. 
By default, it will use KIND_SphereAggregation where KIND is the kind of data it was applied to (default None). + Raises + ------ + ValueError + If ``time_method`` is specified for non-BOLD data or if + ``time_method_params`` is not None when ``time_method`` is None. + """ _DEPENDENCIES: ClassVar[Set[str]] = {"nilearn", "numpy"} @@ -118,6 +124,11 @@ def get_output_type(self, input_type: str) -> str: str The storage type output by the marker. + Raises + ------ + ValueError + If the ``input_type`` is invalid. + """ if input_type in ["VBM_GM", "VBM_WM", "fALFF", "GCOR", "LCOR"]: @@ -155,6 +166,11 @@ def compute( * ``data`` : the actual computed values as a numpy.ndarray * ``col_names`` : the column labels for the computed values as list + Warns + ----- + RuntimeWarning + If time aggregation is required but only one time point is available. + """ t_input_img = input["data"] logger.debug(f"Sphere aggregation using {self.method}") From fc9b29e5beac716462a8cc530e6a80605cfc64f6 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 20 Sep 2023 15:13:47 +0200 Subject: [PATCH 08/13] chore: improve error reporting for SphereAggregation --- junifer/markers/sphere_aggregation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/junifer/markers/sphere_aggregation.py b/junifer/markers/sphere_aggregation.py index ad90ecc773..841159e794 100644 --- a/junifer/markers/sphere_aggregation.py +++ b/junifer/markers/sphere_aggregation.py @@ -136,7 +136,7 @@ def get_output_type(self, input_type: str) -> str: elif input_type == "BOLD": return "timeseries" else: - raise ValueError(f"Unknown input kind for {input_type}") + raise_error(f"Unknown input kind for {input_type}") def compute( self, From ae5b7e2ee5332a572e1885f7961bbcf9f2c8592e Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 27 Sep 2023 10:54:02 +0200 Subject: [PATCH 09/13] chore: improve commentary in SphereAggregation --- junifer/markers/sphere_aggregation.py | 5 +++++ 1 file changed, 5 insertions(+) diff 
--git a/junifer/markers/sphere_aggregation.py b/junifer/markers/sphere_aggregation.py index 841159e794..002165e8ed 100644 --- a/junifer/markers/sphere_aggregation.py +++ b/junifer/markers/sphere_aggregation.py @@ -190,9 +190,12 @@ def compute( mask_img = None if self.masks is not None: logger.debug(f"Masking with {self.masks}") + # Get tailored mask mask_img = get_mask( masks=self.masks, target_data=input, extra_input=extra_input ) + + # Initialize masker masker = JuniferNiftiSpheresMasker( seeds=coords, radius=self.radius, @@ -202,6 +205,8 @@ def compute( ) # Fit and transform the marker on the data out_values = masker.fit_transform(t_input_img) + + # Apply time dimension aggregation if required if self.time_method is not None: if out_values.shape[0] > 1: logger.debug("Aggregating time dimension") From 0eba3e146224ca1fddb78d693fb0cb29ddbfa274 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 27 Sep 2023 10:54:17 +0200 Subject: [PATCH 10/13] chore: improve logging for SphereAggregation --- junifer/markers/sphere_aggregation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/junifer/markers/sphere_aggregation.py b/junifer/markers/sphere_aggregation.py index 002165e8ed..e8eec9c737 100644 --- a/junifer/markers/sphere_aggregation.py +++ b/junifer/markers/sphere_aggregation.py @@ -196,6 +196,7 @@ def compute( ) # Initialize masker + logger.debug("Masking") masker = JuniferNiftiSpheresMasker( seeds=coords, radius=self.radius, From 3e3dec7c604cac872881886fe3b11240445e6e5e Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 17 Oct 2023 17:35:09 +0200 Subject: [PATCH 11/13] chore: add changelog 265.feature --- docs/changes/newsfragments/265.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/changes/newsfragments/265.feature diff --git a/docs/changes/newsfragments/265.feature b/docs/changes/newsfragments/265.feature new file mode 100644 index 0000000000..eb3416c85f --- /dev/null +++ b/docs/changes/newsfragments/265.feature @@ -0,0 +1 @@ 
+Introduce :func:`.get_coordinates` to fetch coordinates tailored for the data by `Synchon Mandal`_ From da3b6cb39d31286e0633e3676078c9e5de779220 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 18 Oct 2023 13:29:24 +0200 Subject: [PATCH 12/13] fix: remove extra return value from load_coordinates() in get_coordinates() --- junifer/data/coordinates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/junifer/data/coordinates.py b/junifer/data/coordinates.py index 078cd7ae65..c9acc55e92 100644 --- a/junifer/data/coordinates.py +++ b/junifer/data/coordinates.py @@ -170,7 +170,7 @@ def get_coordinates( """ # Load the coordinates - seeds, labels, _ = load_coordinates(name=coords) + seeds, labels = load_coordinates(name=coords) return seeds, labels From a1b8951f665af7c33e56a97da3d54884841888db Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 18 Oct 2023 13:29:50 +0200 Subject: [PATCH 13/13] fix: test fixes in test_coordinates.py --- junifer/data/tests/test_coordinates.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/junifer/data/tests/test_coordinates.py b/junifer/data/tests/test_coordinates.py index c20f7f4437..f762c2f79c 100644 --- a/junifer/data/tests/test_coordinates.py +++ b/junifer/data/tests/test_coordinates.py @@ -57,7 +57,7 @@ def test_register_coordinates_overwrite() -> None: def test_register_coordinates_valid_input() -> None: """Test coordinates registration check for valid input.""" - with pytest.raises(ValueError, match=r"numpy.ndarray"): + with pytest.raises(TypeError, match=r"numpy.ndarray"): register_coordinates( name="MyList", coordinates=[1, 2], @@ -122,7 +122,7 @@ def test_get_coordinates() -> None: coords="DMNBuckner", target_data=vbm_gm ) # Get raw coordinates - raw_coords, raw_labels, _ = load_coordinates("DMNBuckner") + raw_coords, raw_labels = load_coordinates("DMNBuckner") # Both tailored and raw should be same for now assert_array_equal(tailored_coords, raw_coords) assert tailored_labels 
== raw_labels