Skip to content

Commit

Permalink
Merge branch 'main' into update/add-ants-warping
Browse files Browse the repository at this point in the history
  • Loading branch information
fraimondo authored Jan 12, 2024
2 parents c283f14 + 4d9e509 commit 9c0b219
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/287.enh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ``mode`` as an aggregation function option in :func:`.get_aggfunc_by_name` by `Synchon Mandal`_
2 changes: 2 additions & 0 deletions junifer/api/tests/test_api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_get_dependency_information_short() -> None:
assert list(dependency_information.keys()) == [
"click",
"numpy",
"scipy",
"datalad",
"pandas",
"nibabel",
Expand All @@ -51,6 +52,7 @@ def test_get_dependency_information_long() -> None:
for key in [
"click",
"numpy",
"scipy",
"datalad",
"pandas",
"nibabel",
Expand Down
22 changes: 18 additions & 4 deletions junifer/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Callable, Dict, List, Optional

import numpy as np
from scipy.stats import trim_mean
from scipy.stats import mode, trim_mean
from scipy.stats.mstats import winsorize

from .utils import logger, raise_error
Expand All @@ -24,10 +24,11 @@ def get_aggfunc_by_name(
Name to identify the function. Currently supported names and
corresponding functions are:
* ``winsorized_mean`` -> :func:`scipy.stats.mstats.winsorize`
* ``mean`` -> :func:`numpy.mean`
* ``std`` -> :func:`numpy.std`
* ``winsorized_mean`` -> :func:`scipy.stats.mstats.winsorize`
* ``trim_mean`` -> :func:`scipy.stats.trim_mean`
* ``mode`` -> :func:`scipy.stats.mode`
* ``std`` -> :func:`numpy.std`
* ``count`` -> :func:`.count`
* ``select`` -> :func:`.select`
Expand All @@ -40,6 +41,7 @@ def get_aggfunc_by_name(
-------
function
Respective function with ``func_params`` parameter set.
"""
from functools import partial # local import to avoid sphinx error

Expand All @@ -51,6 +53,7 @@ def get_aggfunc_by_name(
"trim_mean",
"count",
"select",
"mode",
}
if func_params is None:
func_params = {}
Expand Down Expand Up @@ -93,6 +96,8 @@ def get_aggfunc_by_name(
elif pick is not None and drop is not None:
raise_error("Either pick or drop must be specified, not both.")
func = partial(select, **func_params)
elif name == "mode":
func = partial(mode, **func_params)
else:
raise_error(
f"Function {name} unknown. Please provide any of "
Expand All @@ -115,6 +120,7 @@ def count(data: np.ndarray, axis: int = 0) -> np.ndarray:
-------
numpy.ndarray
Number of elements along the given axis.
"""
ax_size = data.shape[axis]
if axis < 0:
Expand All @@ -137,7 +143,7 @@ def winsorized_mean(
The axis to calculate winsorized mean on (default None).
**win_params : dict
Dictionary containing the keyword arguments for the winsorize function.
E.g. ``{'limits': [0.1, 0.1]}``.
E.g., ``{'limits': [0.1, 0.1]}``.
Returns
-------
Expand All @@ -149,6 +155,7 @@ def winsorized_mean(
--------
scipy.stats.mstats.winsorize :
The winsorize function used in this function.
"""
win_dat = winsorize(data, axis=axis, **win_params)
win_mean = win_dat.mean(axis=axis)
Expand Down Expand Up @@ -180,6 +187,13 @@ def select(
numpy.ndarray
Subset of the inputted data with the select settings
applied as specified in ``select_params``.
Raises
------
ValueError
If both ``pick`` and ``drop`` are None or
if both ``pick`` and ``drop`` are not None.
"""

if pick is None and drop is None:
Expand Down
2 changes: 2 additions & 0 deletions junifer/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
("count", None),
("trim_mean", None),
("trim_mean", {"proportiontocut": 0.1}),
("mode", None),
("mode", {"keepdims": True}),
],
)
def test_get_aggfunc_by_name(name: str, params: Optional[Dict]) -> None:
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ classifiers = [
dependencies = [
"click>=8.1.3,<8.2",
"numpy>=1.24,<1.27",
"scipy>=1.9.0,<=1.11.4",
"datalad>=0.15.4,<0.20",
"pandas>=1.4.0,<2.2",
"nibabel>=3.2.0,<5.11",
Expand Down Expand Up @@ -177,6 +178,7 @@ known-first-party = ["junifer"]
known-third-party =[
"click",
"numpy",
"scipy",
"datalad",
"pandas",
"nibabel",
Expand Down

0 comments on commit 9c0b219

Please sign in to comment.