Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: Add static type checking #28

Merged
merged 29 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
f20fe5f
ENH: Add static type checking
jhlegarreta Dec 20, 2024
734862c
ENH: Comment empty `latex_elements` section in documentation config file
jhlegarreta Dec 21, 2024
12de4f9
ENH: Fix arguments to `unique` in GP error analysis plot script
jhlegarreta Dec 21, 2024
b06eff8
ENH: Convert list to array prior to computing mean and std dev
jhlegarreta Dec 21, 2024
3eb17ab
BUG: Use appropriate keyword arg names to instantiate `SphericalKriging`
jhlegarreta Dec 21, 2024
781104a
ENH: Select appropriate element in case `predict` returns a tuple
jhlegarreta Dec 21, 2024
d426757
ENH: Group keyword arguments into a single dictionary
jhlegarreta Dec 21, 2024
bc48dcf
ENH: Use `str` instead of `Path` for `_parse_yaml_config` parameter
jhlegarreta Dec 21, 2024
d3591b4
ENH: Use arrays in NumPy's `percentile` arguments
jhlegarreta Dec 21, 2024
3b214c3
ENH: List `sigma_sq` in the GP model slots
jhlegarreta Dec 21, 2024
590d235
ENH: Add the dimensionality to the `mask` ndarray parameter annotation
jhlegarreta Dec 21, 2024
ecbd574
ENH: Fix type hint for `figsize` parameter
jhlegarreta Dec 21, 2024
73afbf7
ENH: Provide appropriate type hints to `reg_target_type`
jhlegarreta Dec 21, 2024
db1a1e7
ENH: Import `Bounds` from `scipy.optimize`
jhlegarreta Dec 21, 2024
87f7e90
ENH: Avoid type checking for private function import statement
jhlegarreta Dec 21, 2024
b8601d9
ENH: Remove unused `namedtuple` definition in test
jhlegarreta Dec 21, 2024
24e4ec1
ENH: Use `ClassVar` for class variable type hinting
jhlegarreta Dec 21, 2024
cc735d1
ENH: Annotate `optimizer` attribute type DiffusionGPR
jhlegarreta Dec 22, 2024
052cf36
chore: Ignore more untyped dependencies
effigies Jan 22, 2025
374fe48
type: Fix annotations in data.base and utils.iterators
effigies Jan 22, 2025
478c860
type: Fix expected type of set_xticklabels arg
effigies Jan 22, 2025
0037efd
chore: Fix DiffusionGPR annotations
effigies Jan 22, 2025
9f3fdd5
type: Annotate PET Dataset
effigies Jan 22, 2025
05237a6
type: Fix annotations for dMRI model
effigies Jan 22, 2025
4eb80dd
type: Fix complaint about mismatched return type
effigies Jan 22, 2025
e240ac6
type: Ignore warning related to bad upstream stubs
effigies Jan 22, 2025
b450b8d
type: Fixes
effigies Jan 22, 2025
2310945
doc: Unmock numpy to allow np.ndarray to be found as a type
effigies Jan 22, 2025
02ff9c4
type: Annotate `Kernel.diag()` argument `X`
effigies Jan 22, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ jobs:
continue-on-error: true
strategy:
matrix:
check: ['spellcheck']

check: ['spellcheck', 'typecheck']
steps:
- uses: actions/checkout@v4
- name: Install the latest version of uv
Expand Down
29 changes: 14 additions & 15 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
"nipype",
"nitime",
"nitransforms",
"numpy",
"pandas",
"scipy",
"seaborn",
Expand Down Expand Up @@ -154,20 +153,20 @@

# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# latex_elements = {
# # The paper size ('letterpaper' or 'a4paper').
# #
# # 'papersize': 'letterpaper',
# # The font size ('10pt', '11pt' or '12pt').
# #
# # 'pointsize': '10pt',
# # Additional stuff for the LaTeX preamble.
# #
# # 'preamble': '',
# # Latex figure (float) alignment
# #
# # 'figure_align': 'htbp',
# }

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
Expand Down
24 changes: 24 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ test = [
"pytest-env",
"pytest-xdist >= 1.28"
]
types = [
"pandas-stubs",
"types-setuptools",
"scipy-stubs",
"types-PyYAML",
"types-tqdm",
"pytest",
"microsoft-python-type-stubs @ git+https://github.com/microsoft/python-type-stubs.git",
]

notebooks = [
"jupyter",
Expand Down Expand Up @@ -138,6 +147,21 @@ version-file = "src/nifreeze/_version.py"
# Developer tool configurations
#

[[tool.mypy.overrides]]
module = [
"nipype.*",
"nilearn.*",
"nireports.*",
"nitransforms.*",
"seaborn",
"dipy.*",
"smac.*",
"joblib",
"h5py",
"ConfigSpace",
]
ignore_missing_imports = true

[tool.ruff]
line-length = 99
target-version = "py310"
Expand Down
17 changes: 12 additions & 5 deletions scripts/dwi_gp_estimation_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def cross_validate(
cv: int,
n_repeats: int,
gpr: DiffusionGPR,
) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]:
) -> np.ndarray:
"""
Perform the experiment by estimating the dMRI signal using a Gaussian process model.

Expand All @@ -74,7 +74,14 @@ def cross_validate(
"""

rkf = RepeatedKFold(n_splits=cv, n_repeats=n_repeats)
scores = cross_val_score(gpr, X, y, scoring="neg_root_mean_squared_error", cv=rkf)
# scikit-learn stubs do not recognize rkf as a BaseCrossValidator
scores = cross_val_score(
gpr,
X,
y,
scoring="neg_root_mean_squared_error",
cv=rkf, # type: ignore[arg-type]
)
return scores


Expand Down Expand Up @@ -204,10 +211,10 @@ def main() -> None:

if args.kfold:
# Use Scikit-learn cross validation
scores = defaultdict(list, {})
scores: dict[str, list] = defaultdict(list, {})
for n in args.kfold:
for i in range(args.repeats):
cv_scores = -1.0 * cross_validate(X, y.T, n, gpr)
cv_scores = -1.0 * cross_validate(X, y.T, n, i, gpr)
scores["rmse"] += cv_scores.tolist()
scores["repeat"] += [i] * len(cv_scores)
scores["n_folds"] += [n] * len(cv_scores)
Expand All @@ -217,7 +224,7 @@ def main() -> None:
print(f"Finished {n}-fold cross-validation")

scores_df = pd.DataFrame(scores)
scores_df.to_csv(args.output_scores, sep="\t", index=None, na_rep="n/a")
scores_df.to_csv(args.output_scores, sep="\t", index=False, na_rep="n/a")

grouped = scores_df.groupby(["n_folds"])
print(grouped[["rmse"]].mean())
Expand Down
16 changes: 12 additions & 4 deletions scripts/dwi_gp_estimation_error_analysis_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,18 @@ def main() -> None:
df = pd.read_csv(args.error_data_fname, sep="\t", keep_default_na=False, na_values="n/a")

# Plot the prediction error
kfolds = sorted(np.unique(df["n_folds"].values))
snr = np.unique(df["snr"].values).item()
bval = np.unique(df["bval"].values).item()
rmse_data = [df.groupby("n_folds").get_group(k)["rmse"].values for k in kfolds]
kfolds = sorted(pd.unique(df["n_folds"]))
snr = pd.unique(df["snr"])
if len(snr) == 1:
snr = snr[0]
else:
raise ValueError(f"More than one unique SNR value: {snr}")
bval = pd.unique(df["bval"])
if len(bval) == 1:
bval = bval[0]
else:
raise ValueError(f"More than one unique bval value: {bval}")
rmse_data = np.asarray([df.groupby("n_folds").get_group(k)["rmse"].values for k in kfolds])
axis = 1
mean = np.mean(rmse_data, axis=axis)
std_dev = np.std(rmse_data, axis=axis)
Expand Down
8 changes: 5 additions & 3 deletions scripts/dwi_gp_estimation_simulated_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ def main() -> None:

# Fit the Gaussian Process regressor and predict on an arbitrary number of
# directions
a = 1.15
lambda_s = 120
beta_a = 1.15
beta_l = 120
alpha = 100
gpr = DiffusionGPR(
kernel=SphericalKriging(a=a, lambda_s=lambda_s),
kernel=SphericalKriging(beta_a=beta_a, beta_l=beta_l),
alpha=alpha,
optimizer=None,
)
Expand All @@ -154,6 +154,8 @@ def main() -> None:
X_test = np.vstack([gtab[~gtab.b0s_mask].bvecs, sph.vertices])

predictions = gpr_fit.predict(X_test)
if isinstance(predictions, tuple):
predictions = predictions[0]

# Save the predicted data
testsims.serialize_dwi(predictions.T, args.dwi_pred_data_fname)
Expand Down
5 changes: 3 additions & 2 deletions scripts/optimize_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,13 @@ async def train_coro(
moving_path = tmp_folder / f"test-{index:04d}.nii.gz"
(~xfm).apply(refnii, reference=refnii).to_filename(moving_path)

_kwargs = {"output_transform_prefix": f"conversion-{index:04d}", **align_kwargs}

cmdline = erants.generate_command(
fixed_path,
moving_path,
fixedmask_path=brainmask_path,
output_transform_prefix=f"conversion-{index:04d}",
**align_kwargs,
**_kwargs,
)

tasks.append(
Expand Down
4 changes: 2 additions & 2 deletions src/nifreeze/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@
import yaml


def _parse_yaml_config(file_path: Path) -> dict:
def _parse_yaml_config(file_path: str) -> dict:
"""
Parse YAML configuration file.

Parameters
----------
file_path : Path
file_path : str
Path to the YAML configuration file.

Returns
Expand Down
43 changes: 26 additions & 17 deletions src/nifreeze/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,23 @@
from collections import namedtuple
from pathlib import Path
from tempfile import mkdtemp
from typing import Any
from typing import Any, Generic, TypeVarTuple

import attr
import h5py
import nibabel as nb
import numpy as np
from nibabel.spatialimages import SpatialHeader, SpatialImage
from nitransforms.linear import Affine

from nifreeze.utils.ndimage import load_api

NFDH5_EXT = ".h5"


Ts = TypeVarTuple("Ts")


def _data_repr(value: np.ndarray | None) -> str:
if value is None:
return "None"
Expand All @@ -52,7 +58,7 @@


@attr.s(slots=True)
class BaseDataset:
class BaseDataset(Generic[*Ts]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main new thing I introduced. This BaseDataset class needs to permit a dynamic number of return values for __getitem__, dependent on the subclass. In the future, we should be able to set this as a default:

class BaseDataset[*Ts=Unpack[tuple[()]]]:
    ...

Which, while not particularly pretty, will allow us not to have to type BaseDataset[()] when annotating a variable that is actually this superclass.

"""
Base dataset representation structure.

Expand All @@ -68,15 +74,15 @@

"""

dataobj = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
dataobj: np.ndarray = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
"""A :obj:`~numpy.ndarray` object for the data array."""
affine = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
affine: np.ndarray = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
"""Best affine for RAS-to-voxel conversion of coordinates (NIfTI header)."""
brainmask = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
brainmask: np.ndarray = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
"""A boolean ndarray object containing a corresponding brainmask."""
motion_affines = attr.ib(default=None, eq=attr.cmp_using(eq=_cmp))
motion_affines: np.ndarray = attr.ib(default=None, eq=attr.cmp_using(eq=_cmp))
"""List of :obj:`~nitransforms.linear.Affine` realigning the dataset."""
datahdr = attr.ib(default=None)
datahdr: SpatialHeader = attr.ib(default=None)
"""A :obj:`~nibabel.spatialimages.SpatialHeader` header corresponding to the data."""

_filepath = attr.ib(
Expand All @@ -93,9 +99,13 @@

return self.dataobj.shape[-1]

def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[*Ts]:
# PY312: Default values for TypeVarTuples are not yet supported
return () # type: ignore[return-value]

def __getitem__(
self, idx: int | slice | tuple | np.ndarray
) -> tuple[np.ndarray, np.ndarray | None]:
) -> tuple[np.ndarray, np.ndarray | None, *Ts]:
"""
Returns volume(s) and corresponding affine(s) through fancy indexing.

Expand All @@ -118,7 +128,7 @@
raise ValueError("No data available (dataobj is None).")

affine = self.motion_affines[idx] if self.motion_affines is not None else None
return self.dataobj[..., idx], affine
return self.dataobj[..., idx], affine, *self._getextra(idx)

@classmethod
def from_filename(cls, filename: Path | str) -> BaseDataset:
Expand Down Expand Up @@ -159,9 +169,8 @@
The order of the spline interpolation.

"""
reference = namedtuple("ImageGrid", ("shape", "affine"))(
shape=self.dataobj.shape[:3], affine=self.affine
)
ImageGrid = namedtuple("ImageGrid", ("shape", "affine"))
reference = ImageGrid(shape=self.dataobj.shape[:3], affine=self.affine)

xform = Affine(matrix=affine, reference=reference)

Expand Down Expand Up @@ -227,7 +236,7 @@
compression_opts=compression_opts,
)

def to_nifti(self, filename: Path) -> None:
def to_nifti(self, filename: Path | str) -> None:
"""
Write a NIfTI file to disk.

Expand All @@ -247,7 +256,7 @@
filename: Path | str,
brainmask_file: Path | str | None = None,
motion_file: Path | str | None = None,
) -> BaseDataset:
) -> BaseDataset[()]:
"""
Load 4D data from a filename or an HDF5 file.

Expand Down Expand Up @@ -279,11 +288,11 @@
if filename.name.endswith(NFDH5_EXT):
return BaseDataset.from_filename(filename)

img = nb.load(filename)
retval = BaseDataset(dataobj=img.dataobj, affine=img.affine)
img = load_api(filename, SpatialImage)
retval: BaseDataset[()] = BaseDataset(dataobj=np.asanyarray(img.dataobj), affine=img.affine)

if brainmask_file:
mask = nb.load(brainmask_file)
mask = load_api(brainmask_file, SpatialImage)

Check warning on line 295 in src/nifreeze/data/base.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/data/base.py#L295

Added line #L295 was not covered by tests
retval.brainmask = np.asanyarray(mask.dataobj)

return retval
Loading
Loading