From dcb5e3407a32e2cbd9085ffbe4ce9430f8bacad0 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sat, 9 Nov 2024 15:26:16 +0100 Subject: [PATCH 01/35] Add PCA-based compression operator (#181) Co-authored-by: Christoph Kolbitsch --- src/mrpro/operators/PCACompressionOp.py | 85 ++++++++++++++++++++++ src/mrpro/operators/__init__.py | 2 + tests/operators/test_pca_compression_op.py | 50 +++++++++++++ 3 files changed, 137 insertions(+) create mode 100644 src/mrpro/operators/PCACompressionOp.py create mode 100644 tests/operators/test_pca_compression_op.py diff --git a/src/mrpro/operators/PCACompressionOp.py b/src/mrpro/operators/PCACompressionOp.py new file mode 100644 index 000000000..ace625ea8 --- /dev/null +++ b/src/mrpro/operators/PCACompressionOp.py @@ -0,0 +1,85 @@ +"""PCA Compression Operator.""" + +import einops +import torch +from einops import repeat + +from mrpro.operators.LinearOperator import LinearOperator + + +class PCACompressionOp(LinearOperator): + """PCA based compression operator.""" + + def __init__( + self, + data: torch.Tensor, + n_components: int, + ): + """Construct a PCA based compression operator. + + The operator carries out an SVD followed by a threshold of the n_components largest values along the last + dimension of a data with shape (*other, joint_dim, compression_dim). A single SVD is carried out for everything + along joint_dim. Other are batch dimensions. + + Consider combining this operator with :class:`mrpro.operators.RearrangeOp` to make sure the data is + in the correct shape before applying. + + Parameters + ---------- + data + Data of shape (*other, joint_dim, compression_dim) to be used to find the principal components. + n_components + Number of principal components to keep along the compression_dim. + """ + super().__init__() + # different compression matrices along the *other dimensions + data = data - data.mean(-1, keepdim=True) + correlation = einops.einsum(data, data.conj(), '... joint comp1, ... joint comp2 -> ... comp1 comp2') + _, _, v = torch.svd(correlation) + # add joint_dim along which the the compression is the same + v = repeat(v, '... comp1 comp2 -> ... joint_dim comp1 comp2', joint_dim=1) + self.register_buffer('_compression_matrix', v[..., :n_components, :].clone()) + + def forward(self, data: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the compression to the data. + + Parameters + ---------- + data + data to be compressed of shape (*other, joint_dim, compression_dim) + + Returns + ------- + compressed data of shape (*other, joint_dim, n_components) + """ + try: + result = (self._compression_matrix @ data.unsqueeze(-1)).squeeze(-1) + except RuntimeError as e: + raise RuntimeError( + 'Shape mismatch in adjoint Compression: ' + f'Matrix {tuple(self._compression_matrix.shape)} ' + f'cannot be multiplied with Data {tuple(data.shape)}.' + ) from e + return (result,) + + def adjoint(self, data: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint compression to the data. + + Parameters + ---------- + data + compressed data of shape (*other, joint_dim, n_components) + + Returns + ------- + expanded data of shape (*other, joint_dim, compression_dim) + """ + try: + result = (self._compression_matrix.mH @ data.unsqueeze(-1)).squeeze(-1) + except RuntimeError as e: + raise RuntimeError( + 'Shape mismatch in adjoint Compression: ' + f'Matrix^H {tuple(self._compression_matrix.mH.shape)} ' + f'cannot be multiplied with Data {tuple(data.shape)}.' + ) from e + return (result,) diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index b8c16ebfe..c22f386cd 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -14,6 +14,7 @@ from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix from mrpro.operators.MagnitudeOp import MagnitudeOp from mrpro.operators.MultiIdentityOp import MultiIdentityOp +from mrpro.operators.PCACompressionOp import PCACompressionOp from mrpro.operators.PhaseOp import PhaseOp from mrpro.operators.ProximableFunctionalSeparableSum import ProximableFunctionalSeparableSum from mrpro.operators.SensitivityOp import SensitivityOp @@ -41,6 +42,7 @@ "MagnitudeOp", "MultiIdentityOp", "Operator", + "PCACompressionOp", "PhaseOp", "ProximableFunctional", "ProximableFunctionalSeparableSum", diff --git a/tests/operators/test_pca_compression_op.py b/tests/operators/test_pca_compression_op.py new file mode 100644 index 000000000..a600bcecb --- /dev/null +++ b/tests/operators/test_pca_compression_op.py @@ -0,0 +1,50 @@ +"""Tests for PCA Compression Operator.""" + +import pytest +from mrpro.operators import PCACompressionOp + +from tests import RandomGenerator +from tests.helper import dotproduct_adjointness_test + + +@pytest.mark.parametrize( + ('init_data_shape', 'input_shape', 'n_components'), + [ + ((40, 10), (100, 10), 6), + ((40, 10), (3, 4, 5, 100, 10), 3), + ((3, 4, 40, 10), (3, 4, 100, 10), 6), + ((3, 4, 40, 10), (7, 3, 4, 100, 10), 3), + ], +) +def test_pca_compression_op_adjoint(init_data_shape, input_shape, n_components): + """Test adjointness of PCA Compression Op.""" + + # Create test data + generator = RandomGenerator(seed=0) + data_to_calculate_compression_matrix_from = generator.complex64_tensor(init_data_shape) + u = generator.complex64_tensor(input_shape) + output_shape = (*input_shape[:-1], n_components) + v = generator.complex64_tensor(output_shape) + + # Create operator and apply + pca_comp_op = PCACompressionOp(data=data_to_calculate_compression_matrix_from, n_components=n_components) + dotproduct_adjointness_test(pca_comp_op, u, v) + + +def test_pca_compression_op_wrong_shapes(): + """Test if Operator raises error if shape mismatch.""" + init_data_shape = (10, 6) + input_shape = (100, 3) + + # Create test data + generator = RandomGenerator(seed=0) + data_to_calculate_compression_matrix_from = generator.complex64_tensor(init_data_shape) + input_data = generator.complex64_tensor(input_shape) + + pca_comp_op = PCACompressionOp(data=data_to_calculate_compression_matrix_from, n_components=2) + + with pytest.raises(RuntimeError, match='Matrix'): + pca_comp_op(input_data) + + with pytest.raises(RuntimeError, match='Matrix.H'): + pca_comp_op.adjoint(input_data) From c268ad25a429f93d3bb33df6cb2bf0ffc06c4759 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Sat, 9 Nov 2024 16:11:57 +0100 Subject: [PATCH 02/35] Fix CartesianSamplingOp (#483) --- src/mrpro/operators/CartesianSamplingOp.py | 10 ++++++---- tests/operators/test_cartesian_sampling_op.py | 17 +++++++++++++++-- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index 64068a5d1..7a51924b1 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -47,26 +47,28 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) -> kx_idx = ktraj_tensor[-1, ...].round().to(dtype=torch.int64) + sorted_grid_shape.x // 2 else: sorted_grid_shape.x = ktraj_tensor.shape[-1] - kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k1 k2 k0', other=1, k2=1, k1=1) + kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k2 k1 k0', other=1, k2=1, k1=1) if traj_type_kzyx[-2] == TrajType.ONGRID: # ky ky_idx = ktraj_tensor[-2, ...].round().to(dtype=torch.int64) + sorted_grid_shape.y // 2 else: sorted_grid_shape.y = ktraj_tensor.shape[-2] - ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k1 k2 k0', other=1, k2=1, k0=1) + ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k2 k1 k0', other=1, k2=1, k0=1) if traj_type_kzyx[-3] == TrajType.ONGRID: # kz kz_idx = ktraj_tensor[-3, ...].round().to(dtype=torch.int64) + sorted_grid_shape.z // 2 else: sorted_grid_shape.z = ktraj_tensor.shape[-3] - kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k1 k2 k0', other=1, k1=1, k0=1) + kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k2 k1 k0', other=1, k1=1, k0=1) # 1D indices into a flattened tensor. kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)') self.register_buffer('_fft_idx', kidx) # we can skip the indexing if the data is already sorted - self._needs_indexing = not torch.all(torch.diff(kidx) == 1) + self._needs_indexing = ( + not torch.all(torch.diff(kidx) == 1) or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx + ) self._trajectory_shape = traj.broadcasted_shape self._sorted_grid_shape = sorted_grid_shape diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 6a1120e79..7caa13e7c 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -2,6 +2,7 @@ import pytest import torch +from einops import rearrange from mrpro.data import KTrajectory, SpatialDimension from mrpro.operators import CartesianSamplingOp @@ -59,6 +60,9 @@ def test_cart_sampling_op_data_match(): 'regular_undersampling', 'random_undersampling', 'different_random_undersampling', + 'cartesian_and_non_cartesian', + 'kx_ky_along_k0', + 'kx_ky_along_k0_undersampling', ], ) def test_cart_sampling_op_fwd_adj(sampling): @@ -70,8 +74,8 @@ def test_cart_sampling_op_fwd_adj(sampling): nky = (2, 1, 40, 1) nkz = (2, 20, 1, 1) sx = 'uf' - sy = 'uf' - sz = 'uf' + sy = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' + sz = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz).as_tensor() # Subsample data and trajectory @@ -94,6 +98,15 @@ def test_cart_sampling_op_fwd_adj(sampling): for traj_one_other in trajectory_tensor.unbind(1) ] trajectory = KTrajectory.from_tensor(torch.stack(traj_list, dim=1)) + case 'cartesian_and_non_cartesian': + trajectory = KTrajectory.from_tensor(trajectory_tensor) + case 'kx_ky_along_k0': + trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)') + trajectory = KTrajectory.from_tensor(trajectory_tensor) + case 'kx_ky_along_k0_undersampling': + trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)') + random_idx = torch.randperm(trajectory_tensor.shape[-1]) + trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: trajectory_tensor.shape[-1] // 2]]) case _: raise NotImplementedError(f'Test {sampling} not implemented.') From 54674a95b59165cbf5cb21e1a71d1bb616921554 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 11 Nov 2024 01:55:28 +0100 Subject: [PATCH 03/35] Add reduce_view (#476) Undoes expand, i.e. replaces stride 0 dimensions by size=1 dimensions --- src/mrpro/utils/__init__.py | 3 ++- src/mrpro/utils/reshape.py | 32 ++++++++++++++++++++++++++++++++ tests/utils/test_reshape.py | 32 +++++++++++++++++++++++++++++--- 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index c09071f4b..b16fae37a 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -5,10 +5,11 @@ from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop from mrpro.utils.modify_acq_info import modify_acq_info from mrpro.utils.split_idx import split_idx -from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right +from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view __all__ = [ "broadcast_right", "modify_acq_info", + "reduce_view", "remove_repeat", "slice_profiles", "smap", diff --git a/src/mrpro/utils/reshape.py b/src/mrpro/utils/reshape.py index 39d12e51f..31d495afd 100644 --- a/src/mrpro/utils/reshape.py +++ b/src/mrpro/utils/reshape.py @@ -1,5 +1,7 @@ """Tensor reshaping utilities.""" +from collections.abc import Sequence + import torch @@ -67,3 +69,33 @@ def broadcast_right(*x: torch.Tensor) -> tuple[torch.Tensor, ...]: max_dim = max(el.ndim for el in x) unsqueezed = torch.broadcast_tensors(*(unsqueeze_right(el, max_dim - el.ndim) for el in x)) return unsqueezed + + +def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torch.Tensor: + """Reduce expanded dimensions in a view to singletons. + + Reduce either all or specific dimensions to a singleton if it + points to the same memory address. + This undoes expand. + + Parameters + ---------- + x + input tensor + dim + only reduce expanded dimensions in the specified dimensions. + If None, reduce all expanded dimensions. + """ + if dim is None: + dim_: Sequence[int] = range(x.ndim) + elif isinstance(dim, Sequence): + dim_ = [d % x.ndim for d in dim] + else: + dim_ = [dim % x.ndim] + + stride = x.stride() + newsize = [ + 1 if stride == 0 and d in dim_ else oldsize + for d, (oldsize, stride) in enumerate(zip(x.size(), stride, strict=True)) + ] + return torch.as_strided(x, newsize, stride) diff --git a/tests/utils/test_reshape.py b/tests/utils/test_reshape.py index 60a0dc5e3..dd57b8feb 100644 --- a/tests/utils/test_reshape.py +++ b/tests/utils/test_reshape.py @@ -1,7 +1,9 @@ """Tests for reshaping utilities.""" import torch -from mrpro.utils import broadcast_right, unsqueeze_left, unsqueeze_right +from mrpro.utils import broadcast_right, reduce_view, unsqueeze_left, unsqueeze_right + +from tests import RandomGenerator def test_broadcast_right(): @@ -12,7 +14,7 @@ def test_broadcast_right(): def test_unsqueeze_left(): - """Test unsqueeze left""" + """Test unsqueeze_left""" tensor = torch.ones(1, 2, 3) unsqueezed = unsqueeze_left(tensor, 2) assert unsqueezed.shape == (1, 1, 1, 2, 3) @@ -20,8 +22,32 @@ def test_unsqueeze_left(): def test_unsqueeze_right(): - """Test unsqueeze right""" + """Test unsqueeze_right""" tensor = torch.ones(1, 2, 3) unsqueezed = unsqueeze_right(tensor, 2) assert unsqueezed.shape == (1, 2, 3, 1, 1) assert torch.equal(tensor.ravel(), unsqueezed.ravel()) + + +def test_reduce_view(): + """Test reduce_view""" + + tensor = RandomGenerator(0).float32_tensor((1, 2, 3, 1, 1, 1)) + tensor = tensor.expand(1, 2, 3, 4, 1, 1).contiguous() # this cannot be removed + tensor = tensor.expand(7, 2, 3, 4, 5, 6) + + reduced_all = reduce_view(tensor) + assert reduced_all.shape == (1, 2, 3, 4, 1, 1) + assert torch.equal(reduced_all.expand_as(tensor), tensor) + + reduced_two = reduce_view(tensor, (0, -1)) + assert reduced_two.shape == (1, 2, 3, 4, 5, 1) + assert torch.equal(reduced_two.expand_as(tensor), tensor) + + reduced_one_neg = reduce_view(tensor, -1) + assert reduced_one_neg.shape == (7, 2, 3, 4, 5, 1) + assert torch.equal(reduced_one_neg.expand_as(tensor), tensor) + + reduced_one_pos = reduce_view(tensor, 0) + assert reduced_one_pos.shape == (1, 2, 3, 4, 5, 6) + assert torch.equal(reduced_one_pos.expand_as(tensor), tensor) From f0f91c3ff91296a07e6bb7a184684f6674a91795 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 11 Nov 2024 11:35:57 +0100 Subject: [PATCH 04/35] Add apply_ to dataclasses (#505) Applies a function to all children of a dataclass --- src/mrpro/data/AcqInfo.py | 14 ++---- src/mrpro/data/MoveDataMixin.py | 68 ++++++++++++++++++++++------ src/mrpro/data/SpatialDimension.py | 55 +++++----------------- tests/data/test_movedatamixin.py | 19 ++++++++ tests/data/test_spatial_dimension.py | 23 ---------- 5 files changed, 89 insertions(+), 90 deletions(-) diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index 83f752a57..a66224de1 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -1,6 +1,6 @@ """Acquisition information dataclass.""" -from collections.abc import Callable, Sequence +from collections.abc import Sequence from dataclasses import dataclass import ismrmrd @@ -206,17 +206,13 @@ def tensor_2d(data: np.ndarray) -> torch.Tensor: data_tensor = data_tensor[None, None] return data_tensor - def spatialdimension_2d( - data: np.ndarray, conversion: Callable[[torch.Tensor], torch.Tensor] | None = None - ) -> SpatialDimension[torch.Tensor]: + def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: # Ensure spatial dimension is (k1*k2*other, 1, 3) if data.ndim != 2: raise ValueError('Spatial dimension is expected to be of shape (N,3)') data = data[:, None, :] # all spatial dimensions are float32 - return ( - SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32))).apply_(conversion) - ) + return SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32))) acq_idx = AcqIdx( k1=tensor(idx['kspace_encode_step_1']), @@ -251,10 +247,10 @@ def spatialdimension_2d( flags=tensor_2d(headers['flags']), measurement_uid=tensor_2d(headers['measurement_uid']), number_of_samples=tensor_2d(headers['number_of_samples']), - patient_table_position=spatialdimension_2d(headers['patient_table_position'], mm_to_m), + patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m), phase_dir=spatialdimension_2d(headers['phase_dir']), physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']), - position=spatialdimension_2d(headers['position'], mm_to_m), + position=spatialdimension_2d(headers['position']).apply_(mm_to_m), read_dir=spatialdimension_2d(headers['read_dir']), sample_time_us=tensor_2d(headers['sample_time_us']), scan_counter=tensor_2d(headers['scan_counter']), diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index f3f147260..8d977d0a6 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -1,13 +1,13 @@ """MoveDataMixin.""" import dataclasses -from collections.abc import Iterator +from collections.abc import Callable, Iterator from copy import copy as shallowcopy from copy import deepcopy -from typing import ClassVar, TypeAlias +from typing import ClassVar, TypeAlias, cast import torch -from typing_extensions import Any, Protocol, Self, overload, runtime_checkable +from typing_extensions import Any, Protocol, Self, TypeVar, overload, runtime_checkable class InconsistentDeviceError(ValueError): # noqa: D101 @@ -22,6 +22,9 @@ class DataclassInstance(Protocol): __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] +T = TypeVar('T') + + class MoveDataMixin: """Move dataclass fields to cpu/gpu and convert dtypes.""" @@ -151,7 +154,6 @@ def _to( copy: bool = False, memo: dict | None = None, ) -> Self: - new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self """Move data to device and convert dtype if necessary. This method is called by .to(), .cuda(), .cpu(), .double(), and so on. @@ -179,6 +181,8 @@ def _to( memo A dictionary to keep track of already converted objects to avoid multiple conversions. """ + new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self + if memo is None: memo = {} @@ -219,26 +223,62 @@ def _mixin_to(obj: MoveDataMixin) -> MoveDataMixin: memo=memo, ) - converted: Any - for name, data in new._items(): - if id(data) in memo: - object.__setattr__(new, name, memo[id(data)]) - continue + def _convert(data: T) -> T: + converted: Any # https://github.com/python/mypy/issues/10817 if isinstance(data, torch.Tensor): converted = _tensor_to(data) elif isinstance(data, MoveDataMixin): converted = _mixin_to(data) elif isinstance(data, torch.nn.Module): converted = _module_to(data) - elif copy: - converted = deepcopy(data) else: converted = data - memo[id(data)] = converted - # this works even if new is frozen - object.__setattr__(new, name, converted) + return cast(T, converted) + + # manual recursion allows us to do the copy only once + new.apply_(_convert, memo=memo, recurse=False) return new + def apply_( + self: Self, + function: Callable[[Any], Any] | None = None, + *, + memo: dict[int, Any] | None = None, + recurse: bool = True, + ) -> Self: + """Apply a function to all children in-place. + + Parameters + ---------- + function + The function to apply to all fields. None is interpreted as a no-op. + memo + A dictionary to keep track of objects that the function has already been applied to, + to avoid multiple applications. This is useful if the object has a circular reference. + recurse + If True, the function will be applied to all children that are MoveDataMixin instances. + """ + applied: Any + + if memo is None: + memo = {} + + if function is None: + return self + + for name, data in self._items(): + if id(data) in memo: + # this works even if self is frozen + object.__setattr__(self, name, memo[id(data)]) + continue + if recurse and isinstance(data, MoveDataMixin): + applied = data.apply_(function, memo=memo) + else: + applied = function(data) + memo[id(data)] = applied + object.__setattr__(self, name, applied) + return self + def cuda( self, device: torch.device | str | int | None = None, diff --git a/src/mrpro/data/SpatialDimension.py b/src/mrpro/data/SpatialDimension.py index b5f3dfd27..46b1db89a 100644 --- a/src/mrpro/data/SpatialDimension.py +++ b/src/mrpro/data/SpatialDimension.py @@ -3,14 +3,13 @@ from __future__ import annotations from collections.abc import Callable -from copy import deepcopy from dataclasses import dataclass from typing import Generic, get_args import numpy as np import torch from numpy.typing import ArrayLike -from typing_extensions import Any, Protocol, TypeVar, overload +from typing_extensions import Protocol, Self, TypeVar, overload import mrpro.utils.typing as type_utils from mrpro.data.MoveDataMixin import MoveDataMixin @@ -109,6 +108,16 @@ def from_array_zyx( return SpatialDimension(z, y, x) + def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self: + """Apply a function to each z, y, x (in-place). + + Parameters + ---------- + function + function to apply + """ + return super(SpatialDimension, self).apply_(function) + @property def zyx(self) -> tuple[T_co, T_co, T_co]: """Return a z,y,x tuple.""" @@ -134,48 +143,6 @@ def __setitem__(self: SpatialDimension[T_co_vector], idx: type_utils.TorchIndexe self.y[idx] = other.y self.x[idx] = other.x - def apply_(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]: - """Apply function to each of x,y,z in-place. - - Parameters - ---------- - func - function to apply to each of x,y,z - None is interpreted as the identity function. - """ - if func is not None: - self.z = func(self.z) - self.y = func(self.y) - self.x = func(self.x) - return self - - def apply(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]: - """Apply function to each of x,y,z. - - Parameters - ---------- - func - function to apply to each of x,y,z - None is interpreted as the identity function. - """ - - def func_(x: Any) -> T_co: # noqa: ANN401 - if isinstance(x, torch.Tensor): - # use clone for autograd - x = x.clone() - else: - x = deepcopy(x) - if func is None: - return x - else: - return func(x) - - return self.__class__(func_(self.z), func_(self.y), func_(self.x)) - - def clone(self: SpatialDimension[T_co]) -> SpatialDimension[T_co]: - """Return a deep copy of the SpatialDimension.""" - return self.apply() - @overload def __mul__(self: SpatialDimension[T_co], other: T_co | SpatialDimension[T_co]) -> SpatialDimension[T_co]: ... diff --git a/tests/data/test_movedatamixin.py b/tests/data/test_movedatamixin.py index 06f55a4dc..3feb091de 100644 --- a/tests/data/test_movedatamixin.py +++ b/tests/data/test_movedatamixin.py @@ -23,6 +23,7 @@ class A(MoveDataMixin): """Test class A.""" floattensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0)) + floattensor2: torch.Tensor = field(default_factory=lambda: torch.tensor(-1.0)) complextensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0, dtype=torch.complex64)) inttensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1, dtype=torch.int32)) booltensor: torch.Tensor = field(default_factory=lambda: torch.tensor(True)) @@ -204,3 +205,21 @@ def testchild(attribute, expected_dtype): assert original is not new, 'original and new should not be the same object' assert new.module.module1.weight is new.module.module1.weight, 'shared module parameters should remain shared' + + +def test_movedatamixin_apply(): + """Tests apply_ method of MoveDataMixin.""" + data = B() + # make one of the parameters shared to test memo behavior + data.child.floattensor2 = data.child.floattensor + original = data.clone() + + def multiply_by_2(obj): + if isinstance(obj, torch.Tensor): + return obj * 2 + return obj + + data.apply_(multiply_by_2) + torch.testing.assert_close(data.floattensor, original.floattensor * 2) + torch.testing.assert_close(data.child.floattensor2, original.child.floattensor2 * 2) + assert data.child.floattensor is data.child.floattensor2, 'shared module parameters should remain shared' diff --git a/tests/data/test_spatial_dimension.py b/tests/data/test_spatial_dimension.py index cd46854b9..afafece04 100644 --- a/tests/data/test_spatial_dimension.py +++ b/tests/data/test_spatial_dimension.py @@ -115,29 +115,6 @@ def conversion(x: torch.Tensor) -> torch.Tensor: assert torch.equal(spatial_dimension_inplace.z, z) -def test_spatial_dimension_apply(): - """Test apply (out of place)""" - - def conversion(x: torch.Tensor) -> torch.Tensor: - assert isinstance(x, torch.Tensor), 'The argument to the conversion function should be a tensor' - return x.swapaxes(0, 1).square() - - xyz = RandomGenerator(0).float32_tensor((1, 2, 3)) - spatial_dimension = SpatialDimension.from_array_xyz(xyz.numpy()) - spatial_dimension_outofplace = spatial_dimension.apply().apply(conversion) - - assert spatial_dimension_outofplace is not spatial_dimension - - assert isinstance(spatial_dimension_outofplace.x, torch.Tensor) - assert isinstance(spatial_dimension_outofplace.y, torch.Tensor) - assert isinstance(spatial_dimension_outofplace.z, torch.Tensor) - - x, y, z = conversion(xyz).unbind(-1) - assert torch.equal(spatial_dimension_outofplace.x, x) - assert torch.equal(spatial_dimension_outofplace.y, y) - assert torch.equal(spatial_dimension_outofplace.z, z) - - def test_spatial_dimension_zyx(): """Test the zyx tuple property""" z, y, x = (2, 3, 4) From a96b9c61f2c667f439124f779e399f374a75c646 Mon Sep 17 00:00:00 2001 From: Patrick Schuenke <37338697+schuenke@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:01:46 +0100 Subject: [PATCH 05/35] Fix installation from TestPyPi in workflow (#499) --- .github/workflows/deployment.yml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/workflows/deployment.yml b/.github/workflows/deployment.yml index dd900496e..a55b7860d 100644 --- a/.github/workflows/deployment.yml +++ b/.github/workflows/deployment.yml @@ -94,7 +94,15 @@ jobs: run: | VERSION=${{ needs.build-testpypi-package.outputs.version }} SUFFIX=${{ needs.build-testpypi-package.outputs.suffix }} - python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ + for i in {1..3}; do + if python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/; then + echo "Package installed successfully." + break + else + echo "Attempt $i failed. Retrying in 10 seconds..." + sleep 10 + fi + done build-pypi-package: name: Build Package for PyPI From 191ab06b35ecae890febb34d0d2841b8fe2cd4c5 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 11 Nov 2024 17:45:56 +0100 Subject: [PATCH 06/35] Remove check-docstring-first pre-commit hook (#508) --- .pre-commit-config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 95d14317a..4790095f9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,6 @@ repos: rev: v5.0.0 hooks: - id: check-added-large-files - - id: check-docstring-first - id: check-merge-conflict - id: check-yaml - id: check-toml From 6c54e31ac22e80cfbafad88102b4ca1ab30d8241 Mon Sep 17 00:00:00 2001 From: Patrick Schuenke <37338697+schuenke@users.noreply.github.com> Date: Tue, 12 Nov 2024 00:27:11 +0100 Subject: [PATCH 07/35] Revert NamedTemporaryFile ContextManager in example notebooks (#500) --- .pre-commit-config.yaml | 16 +++++++------- examples/direct_reconstruction.ipynb | 8 +++---- examples/direct_reconstruction.py | 8 +++---- examples/iterative_sense_reconstruction.ipynb | 8 +++---- examples/iterative_sense_reconstruction.py | 8 +++---- examples/pulseq_2d_radial_golden_angle.ipynb | 22 +++++++++---------- examples/pulseq_2d_radial_golden_angle.py | 17 +++++++------- ...rized_iterative_sense_reconstruction.ipynb | 8 +++---- ...ularized_iterative_sense_reconstruction.py | 8 +++---- examples/ruff.toml | 1 + 10 files changed, 51 insertions(+), 53 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4790095f9..303fd43fa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,11 +53,11 @@ repos: - "--extra-index-url=https://pypi.python.org/simple" ci: - autofix_commit_msg: | - [pre-commit] auto fixes from pre-commit hooks - autofix_prs: false - autoupdate_branch: '' - autoupdate_commit_msg: '[pre-commit] pre-commit autoupdate' - autoupdate_schedule: monthly - skip: [mypy] - submodules: false + autofix_commit_msg: | + [pre-commit] auto fixes from pre-commit hooks + autofix_prs: false + autoupdate_branch: "" + autoupdate_commit_msg: "[pre-commit] pre-commit autoupdate" + autoupdate_schedule: monthly + skip: [mypy] + submodules: false diff --git a/examples/direct_reconstruction.ipynb b/examples/direct_reconstruction.ipynb index 1e4e74c9c..3b6dc930e 100644 --- a/examples/direct_reconstruction.ipynb +++ b/examples/direct_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", - " response = requests.get(zenodo_url + fname, timeout=30)\n", - " data_file.write(response.content)\n", - " data_file.flush()" + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" ] }, { diff --git a/examples/direct_reconstruction.py b/examples/direct_reconstruction.py index 5d55812c9..7672aa7e7 100644 --- a/examples/direct_reconstruction.py +++ b/examples/direct_reconstruction.py @@ -11,10 +11,10 @@ import requests -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: - response = requests.get(zenodo_url + fname, timeout=30) - data_file.write(response.content) - data_file.flush() +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/iterative_sense_reconstruction.ipynb b/examples/iterative_sense_reconstruction.ipynb index f612d7522..87249b2fb 100644 --- a/examples/iterative_sense_reconstruction.ipynb +++ b/examples/iterative_sense_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", - " response = requests.get(zenodo_url + fname, timeout=30)\n", - " data_file.write(response.content)\n", - " data_file.flush()" + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" ] }, { diff --git a/examples/iterative_sense_reconstruction.py b/examples/iterative_sense_reconstruction.py index ba5e6a01a..6d0bc49a5 100644 --- a/examples/iterative_sense_reconstruction.py +++ b/examples/iterative_sense_reconstruction.py @@ -11,10 +11,10 @@ import requests -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: - response = requests.get(zenodo_url + fname, timeout=30) - data_file.write(response.content) - data_file.flush() +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/pulseq_2d_radial_golden_angle.ipynb b/examples/pulseq_2d_radial_golden_angle.ipynb index bcb4482a1..52e0310bb 100644 --- a/examples/pulseq_2d_radial_golden_angle.ipynb +++ b/examples/pulseq_2d_radial_golden_angle.ipynb @@ -33,14 +33,13 @@ "cell_type": "code", "execution_count": null, "id": "d16f41f1", - "metadata": { - "lines_to_next_cell": 2 - }, + "metadata": {}, "outputs": [], "source": [ "# define zenodo records URL and create a temporary directory and h5-file\n", "zenodo_url = 'https://zenodo.org/records/10854057/files/'\n", - "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'" + "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'\n", + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')" ] }, { @@ -51,10 +50,9 @@ "outputs": [], "source": [ "# Download raw data using requests\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", - " response = requests.get(zenodo_url + fname, timeout=30)\n", - " data_file.write(response.content)\n", - " data_file.flush()" + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" ] }, { @@ -127,10 +125,10 @@ "# download the sequence file from zenodo\n", "zenodo_url = 'https://zenodo.org/records/10868061/files/'\n", "seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq'\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file:\n", - " response = requests.get(zenodo_url + seq_fname, timeout=30)\n", - " seq_file.write(response.content)\n", - " seq_file.flush()" + "seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq')\n", + "response = requests.get(zenodo_url + seq_fname, timeout=30)\n", + "seq_file.write(response.content)\n", + "seq_file.flush()" ] }, { diff --git a/examples/pulseq_2d_radial_golden_angle.py b/examples/pulseq_2d_radial_golden_angle.py index f4db5217a..3f857c382 100644 --- a/examples/pulseq_2d_radial_golden_angle.py +++ b/examples/pulseq_2d_radial_golden_angle.py @@ -19,14 +19,13 @@ # define zenodo records URL and create a temporary directory and h5-file zenodo_url = 'https://zenodo.org/records/10854057/files/' fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5' - +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') # %% # Download raw data using requests -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: - response = requests.get(zenodo_url + fname, timeout=30) - data_file.write(response.content) - data_file.flush() +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() # %% [markdown] # ### Image reconstruction using KTrajectoryIsmrmrd @@ -63,10 +62,10 @@ # download the sequence file from zenodo zenodo_url = 'https://zenodo.org/records/10868061/files/' seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq' -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file: - response = requests.get(zenodo_url + seq_fname, timeout=30) - seq_file.write(response.content) - seq_file.flush() +seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') +response = requests.get(zenodo_url + seq_fname, timeout=30) +seq_file.write(response.content) +seq_file.flush() # %% # Read raw data and calculate trajectory using KTrajectoryPulseq diff --git a/examples/regularized_iterative_sense_reconstruction.ipynb b/examples/regularized_iterative_sense_reconstruction.ipynb index 0a6743161..6b1c2704b 100644 --- a/examples/regularized_iterative_sense_reconstruction.ipynb +++ b/examples/regularized_iterative_sense_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", - " response = requests.get(zenodo_url + fname, timeout=30)\n", - " data_file.write(response.content)\n", - " data_file.flush()" + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" ] }, { diff --git a/examples/regularized_iterative_sense_reconstruction.py b/examples/regularized_iterative_sense_reconstruction.py index 2ab7ba033..e41dc4ac5 100644 --- a/examples/regularized_iterative_sense_reconstruction.py +++ b/examples/regularized_iterative_sense_reconstruction.py @@ -11,10 +11,10 @@ import requests -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: - response = requests.get(zenodo_url + fname, timeout=30) - data_file.write(response.content) - data_file.flush() +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/ruff.toml b/examples/ruff.toml index 11a1e6167..1bb114755 100644 --- a/examples/ruff.toml +++ b/examples/ruff.toml @@ -5,4 +5,5 @@ lint.extend-ignore = [ "T20", #print "E402", #module-import-not-at-top-of-file "S101", #assert + "SIM115", #context manager for opening files ] From a89df61455675201b82e86c19b4b8b9743a5068b Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 12 Nov 2024 08:58:11 +0100 Subject: [PATCH 08/35] Adapt KHeader parameters (#506) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/mrpro/data/AcqInfo.py | 41 ++++----- src/mrpro/data/KHeader.py | 48 ++++------ src/mrpro/data/TrajectoryDescription.py | 29 ------ src/mrpro/data/__init__.py | 4 +- src/mrpro/data/_kdata/KData.py | 15 ++-- src/mrpro/data/_kdata/KDataRearrangeMixin.py | 10 +-- src/mrpro/data/_kdata/KDataSelectMixin.py | 9 +- src/mrpro/data/_kdata/KDataSplitMixin.py | 28 +++--- src/mrpro/utils/__init__.py | 7 +- src/mrpro/utils/modify_acq_info.py | 35 -------- src/mrpro/utils/unit_conversion.py | 94 ++++++++++++++++++++ tests/conftest.py | 6 +- tests/data/test_kdata.py | 34 +++++-- tests/utils/test_modify_acq_info.py | 18 ---- tests/utils/test_unit_conversion.py | 82 +++++++++++++++++ 15 files changed, 278 insertions(+), 182 deletions(-) delete mode 100644 src/mrpro/data/TrajectoryDescription.py delete mode 100644 src/mrpro/utils/modify_acq_info.py create mode 100644 src/mrpro/utils/unit_conversion.py delete mode 100644 tests/utils/test_modify_acq_info.py create mode 100644 tests/utils/test_unit_conversion.py diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index a66224de1..f5d677f97 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -6,23 +6,24 @@ import ismrmrd import numpy as np import torch -from typing_extensions import Self, TypeVar +from einops import rearrange +from typing_extensions import Self from mrpro.data.MoveDataMixin import MoveDataMixin +from mrpro.data.Rotation import Rotation from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.utils.unit_conversion import mm_to_m -# Conversion functions for units -T = TypeVar('T', float, torch.Tensor) +def rearrange_acq_info_fields(field: object, pattern: str, **axes_lengths: dict[str, int]) -> object: + """Change the shape of the fields in AcqInfo.""" + if isinstance(field, Rotation): + return Rotation.from_matrix(rearrange(field.as_matrix(), pattern, **axes_lengths)) -def ms_to_s(ms: T) -> T: - """Convert ms to s.""" - return ms / 1000 + if isinstance(field, torch.Tensor): + return rearrange(field, pattern, **axes_lengths) - -def mm_to_m(m: T) -> T: - """Convert mm to m.""" - return m / 1000 + return field @dataclass(slots=True) @@ -121,30 +122,24 @@ class AcqInfo(MoveDataMixin): number_of_samples: torch.Tensor """Number of sample points per readout (readouts may have different number of sample points).""" + orientation: Rotation + """Rotation describing the orientation of the readout, phase and slice encoding direction.""" + patient_table_position: SpatialDimension[torch.Tensor] """Offset position of the patient table, in LPS coordinates [m].""" - phase_dir: SpatialDimension[torch.Tensor] - """Directional cosine of phase encoding (2D).""" - physiology_time_stamp: torch.Tensor """Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units""" position: SpatialDimension[torch.Tensor] """Center of the excited volume, in LPS coordinates relative to isocenter [m].""" - read_dir: SpatialDimension[torch.Tensor] - """Directional cosine of readout/frequency encoding.""" - sample_time_us: torch.Tensor """Readout bandwidth, as time between samples [us].""" scan_counter: torch.Tensor """Zero-indexed incrementing counter for readouts.""" - slice_dir: SpatialDimension[torch.Tensor] - """Directional cosine of slice normal, i.e. cross-product of read_dir and phase_dir.""" - trajectory_dimensions: torch.Tensor # =3. We only support 3D Trajectories: kz always exists. """Dimensionality of the k-space trajectory vector.""" @@ -247,14 +242,16 @@ def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: flags=tensor_2d(headers['flags']), measurement_uid=tensor_2d(headers['measurement_uid']), number_of_samples=tensor_2d(headers['number_of_samples']), + orientation=Rotation.from_directions( + spatialdimension_2d(headers['slice_dir']), + spatialdimension_2d(headers['phase_dir']), + spatialdimension_2d(headers['read_dir']), + ), patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m), - phase_dir=spatialdimension_2d(headers['phase_dir']), physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']), position=spatialdimension_2d(headers['position']).apply_(mm_to_m), - read_dir=spatialdimension_2d(headers['read_dir']), sample_time_us=tensor_2d(headers['sample_time_us']), scan_counter=tensor_2d(headers['scan_counter']), - slice_dir=spatialdimension_2d(headers['slice_dir']), trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above user_float=tensor_2d(headers['user_float']), user_int=tensor_2d(headers['user_int']), diff --git a/src/mrpro/data/KHeader.py b/src/mrpro/data/KHeader.py index 488c1d9cd..dea12e28b 100644 --- a/src/mrpro/data/KHeader.py +++ b/src/mrpro/data/KHeader.py @@ -13,12 +13,12 @@ from typing_extensions import Self from mrpro.data import enums -from mrpro.data.AcqInfo import AcqInfo, mm_to_m, ms_to_s +from mrpro.data.AcqInfo import AcqInfo from mrpro.data.EncodingLimits import EncodingLimits from mrpro.data.MoveDataMixin import MoveDataMixin from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.data.TrajectoryDescription import TrajectoryDescription from mrpro.utils.summarize_tensorvalues import summarize_tensorvalues +from mrpro.utils.unit_conversion import mm_to_m, ms_to_s if TYPE_CHECKING: # avoid circular imports by importing only when type checking @@ -40,9 +40,6 @@ class KHeader(MoveDataMixin): trajectory: KTrajectoryCalculator """Function to calculate the k-space trajectory.""" - b0: float - """Magnetic field strength [T].""" - encoding_limits: EncodingLimits """K-space encoding limits.""" @@ -61,12 +58,9 @@ class KHeader(MoveDataMixin): acq_info: AcqInfo """Information of the acquisitions (i.e. readout lines).""" - h1_freq: float + lamor_frequency_proton: float """Lamor frequency of hydrogen nuclei [Hz].""" - n_coils: int | None = None - """Number of receiver coils.""" - datetime: datetime.datetime | None = None """Date and time of acquisition.""" @@ -88,7 +82,7 @@ class KHeader(MoveDataMixin): echo_train_length: int = 1 """Number of echoes in a multi-echo acquisition.""" - seq_type: str = UNKNOWN + sequence_type: str = UNKNOWN """Type of sequence.""" model: str = UNKNOWN @@ -100,16 +94,13 @@ class KHeader(MoveDataMixin): protocol_name: str = UNKNOWN """Name of the acquisition protocol.""" - misc: dict = dataclasses.field(default_factory=dict) # do not use {} here! - """Dictionary with miscellaneous parameters.""" - calibration_mode: enums.CalibrationMode = enums.CalibrationMode.OTHER """Mode of how calibration data is acquired. """ interleave_dim: enums.InterleavingDimension = enums.InterleavingDimension.OTHER """Interleaving dimension.""" - traj_type: enums.TrajectoryType = enums.TrajectoryType.OTHER + trajectory_type: enums.TrajectoryType = enums.TrajectoryType.OTHER """Type of trajectory.""" measurement_id: str = UNKNOWN @@ -118,8 +109,9 @@ class KHeader(MoveDataMixin): patient_name: str = UNKNOWN """Name of the patient.""" - trajectory_description: TrajectoryDescription = dataclasses.field(default_factory=TrajectoryDescription) - """Description of the trajectory.""" + _misc: dict = dataclasses.field(default_factory=dict) # do not use {} here! + """Dictionary with miscellaneous parameters. These parameters are for information purposes only. Reconstruction + algorithms should not rely on them.""" @property def fa_degree(self) -> torch.Tensor | None: @@ -160,17 +152,14 @@ def from_ismrmrd( enc: ismrmrdschema.encodingType = header.encoding[encoding_number] # These are guaranteed to exist - parameters = {'h1_freq': header.experimentalConditions.H1resonanceFrequency_Hz, 'acq_info': acq_info} + parameters = { + 'lamor_frequency_proton': header.experimentalConditions.H1resonanceFrequency_Hz, + 'acq_info': acq_info, + } if defaults is not None: parameters.update(defaults) - if ( - header.acquisitionSystemInformation is not None - and header.acquisitionSystemInformation.receiverChannels is not None - ): - parameters['n_coils'] = header.acquisitionSystemInformation.receiverChannels - if header.sequenceParameters is not None: if header.sequenceParameters.TR: parameters['tr'] = ms_to_s(torch.as_tensor(header.sequenceParameters.TR)) @@ -184,7 +173,7 @@ def from_ismrmrd( parameters['echo_spacing'] = ms_to_s(torch.as_tensor(header.sequenceParameters.echo_spacing)) if header.sequenceParameters.sequence_type is not None: - parameters['seq_type'] = header.sequenceParameters.sequence_type + parameters['sequence_type'] = header.sequenceParameters.sequence_type if enc.reconSpace is not None: parameters['recon_fov'] = SpatialDimension[float].from_xyz(enc.reconSpace.fieldOfView_mm).apply_(mm_to_m) @@ -212,7 +201,7 @@ def from_ismrmrd( ) if enc.trajectory is not None: - parameters['traj_type'] = enums.TrajectoryType(enc.trajectory.value) + parameters['trajectory_type'] = enums.TrajectoryType(enc.trajectory.value) # Either use the series or study time if available if header.measurementInformation is not None and header.measurementInformation.seriesTime is not None: @@ -245,15 +234,8 @@ def from_ismrmrd( if header.acquisitionSystemInformation.systemModel is not None: parameters['model'] = header.acquisitionSystemInformation.systemModel - if header.acquisitionSystemInformation.systemFieldStrength_T is not None: - parameters['b0'] = header.acquisitionSystemInformation.systemFieldStrength_T - - # estimate b0 from h1_freq if not given - if 'b0' not in parameters: - parameters['b0'] = parameters['h1_freq'] / 4258e4 - # Dump everything into misc - parameters['misc'] = dataclasses.asdict(header) + parameters['_misc'] = dataclasses.asdict(header) if overwrite is not None: parameters.update(overwrite) diff --git a/src/mrpro/data/TrajectoryDescription.py b/src/mrpro/data/TrajectoryDescription.py deleted file mode 100644 index 801811005..000000000 --- a/src/mrpro/data/TrajectoryDescription.py +++ /dev/null @@ -1,29 +0,0 @@ -"""TrajectoryDescription dataclass.""" - -import dataclasses -from dataclasses import dataclass - -from ismrmrd.xsd.ismrmrdschema.ismrmrd import trajectoryDescriptionType -from typing_extensions import Self - - -@dataclass(slots=True) -class TrajectoryDescription: - """TrajectoryDescription dataclass.""" - - identifier: str = '' - user_parameter_long: dict[str, int] = dataclasses.field(default_factory=dict) - user_parameter_double: dict[str, float] = dataclasses.field(default_factory=dict) - user_parameter_string: dict[str, str] = dataclasses.field(default_factory=dict) - comment: str = '' - - @classmethod - def from_ismrmrd(cls, trajectory_description: trajectoryDescriptionType) -> Self: - """Create TrajectoryDescription from ismrmrd traj description.""" - return cls( - user_parameter_long={p.name: int(p.value) for p in trajectory_description.userParameterLong}, - user_parameter_double={p.name: float(p.value) for p in trajectory_description.userParameterDouble}, - user_parameter_string={p.name: str(p.value) for p in trajectory_description.userParameterString}, - comment=trajectory_description.comment or '', - identifier=trajectory_description.identifier or '', - ) diff --git a/src/mrpro/data/__init__.py b/src/mrpro/data/__init__.py index b5034d668..d5667a5bc 100644 --- a/src/mrpro/data/__init__.py +++ b/src/mrpro/data/__init__.py @@ -16,7 +16,6 @@ from mrpro.data.QHeader import QHeader from mrpro.data.Rotation import Rotation from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.data.TrajectoryDescription import TrajectoryDescription __all__ = [ "AcqIdx", "AcqInfo", @@ -37,8 +36,7 @@ "QHeader", "Rotation", "SpatialDimension", - "TrajectoryDescription", "acq_filters", "enums", "traj_calculators" -] \ No newline at end of file +] diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/_kdata/KData.py index 409e8aac9..57af617bc 100644 --- a/src/mrpro/data/_kdata/KData.py +++ b/src/mrpro/data/_kdata/KData.py @@ -18,15 +18,15 @@ from mrpro.data._kdata.KDataSelectMixin import KDataSelectMixin from mrpro.data._kdata.KDataSplitMixin import KDataSplitMixin from mrpro.data.acq_filters import is_image_acquisition -from mrpro.data.AcqInfo import AcqInfo +from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.KTrajectoryRawShape import KTrajectoryRawShape from mrpro.data.MoveDataMixin import MoveDataMixin +from mrpro.data.Rotation import Rotation from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator from mrpro.data.traj_calculators.KTrajectoryIsmrmrd import KTrajectoryIsmrmrd -from mrpro.utils import modify_acq_info KDIM_SORT_LABELS = ( 'k1', @@ -200,10 +200,13 @@ def from_file( sort_idx = np.lexsort(acq_indices) # torch does not have lexsort as of pytorch 2.2 (March 2024) # Finally, reshape and sort the tensors in acqinfo and acqinfo.idx, and kdata. - def sort_and_reshape_tensor_fields(input_tensor: torch.Tensor): - return rearrange(input_tensor[sort_idx], '(other k2 k1) ... -> other k2 k1 ...', k1=n_k1, k2=n_k2) - - kheader.acq_info = modify_acq_info(sort_and_reshape_tensor_fields, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field[sort_idx], '(other k2 k1) ... -> other k2 k1 ...', k1=n_k1, k2=n_k2 + ) + if isinstance(field, torch.Tensor | Rotation) + else field + ) kdata = rearrange(kdata[sort_idx], '(other k2 k1) coils k0 -> other coils k2 k1 k0', k1=n_k1, k2=n_k2) # Calculate trajectory and check if it matches the kdata shape diff --git a/src/mrpro/data/_kdata/KDataRearrangeMixin.py b/src/mrpro/data/_kdata/KDataRearrangeMixin.py index 05bab7681..23a58dea6 100644 --- a/src/mrpro/data/_kdata/KDataRearrangeMixin.py +++ b/src/mrpro/data/_kdata/KDataRearrangeMixin.py @@ -6,8 +6,7 @@ from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.data.AcqInfo import AcqInfo -from mrpro.utils import modify_acq_info +from mrpro.data.AcqInfo import rearrange_acq_info_fields class KDataRearrangeMixin(_KDataProtocol): @@ -35,9 +34,8 @@ def rearrange_k2_k1_into_k1(self: Self) -> Self: kheader = copy.deepcopy(self.header) # Update shape of acquisition info index - def reshape_acq_info(info: AcqInfo): - return rearrange(info, 'other k2 k1 ... -> other 1 (k2 k1) ...') - - kheader.acq_info = modify_acq_info(reshape_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(field, 'other k2 k1 ... -> other 1 (k2 k1) ...') + ) return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) diff --git a/src/mrpro/data/_kdata/KDataSelectMixin.py b/src/mrpro/data/_kdata/KDataSelectMixin.py index fb0e02aa1..8f8a452cf 100644 --- a/src/mrpro/data/_kdata/KDataSelectMixin.py +++ b/src/mrpro/data/_kdata/KDataSelectMixin.py @@ -7,7 +7,7 @@ from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.utils import modify_acq_info +from mrpro.data.Rotation import Rotation class KDataSelectMixin(_KDataProtocol): @@ -51,10 +51,9 @@ def select_other_subset( other_idx = torch.cat([torch.where(idx == label_idx[:, 0, 0])[0] for idx in subset_idx], dim=0) # Adapt header - def select_acq_info(info: torch.Tensor): - return info[other_idx, ...] - - kheader.acq_info = modify_acq_info(select_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: field[other_idx, ...] if isinstance(field, torch.Tensor | Rotation) else field + ) # Select data kdat = self.data[other_idx, ...] diff --git a/src/mrpro/data/_kdata/KDataSplitMixin.py b/src/mrpro/data/_kdata/KDataSplitMixin.py index d2c641125..c28004af4 100644 --- a/src/mrpro/data/_kdata/KDataSplitMixin.py +++ b/src/mrpro/data/_kdata/KDataSplitMixin.py @@ -1,15 +1,17 @@ """Mixin class to split KData into other subsets.""" -import copy -from typing import Literal +from typing import Literal, TypeVar, cast import torch from einops import rearrange, repeat from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol +from mrpro.data.AcqInfo import rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits -from mrpro.utils import modify_acq_info +from mrpro.data.Rotation import Rotation + +RotationOrTensor = TypeVar('RotationOrTensor', bound=torch.Tensor | Rotation) class KDataSplitMixin(_KDataProtocol): @@ -56,8 +58,9 @@ def _split_k2_or_k1_into_other( def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: return dat_traj[:, :, :, split_idx, :] - def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: - return acq_info[:, :, split_idx, ...] + def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + # cast due to https://github.com/python/mypy/issues/10817 + return cast(RotationOrTensor, acq_info[:, :, split_idx, ...]) # Rearrange other_split and k1 dimension rearrange_pattern_data = 'other coils k2 other_split k1 k0->(other other_split) coils k2 k1 k0' @@ -69,8 +72,8 @@ def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: return dat_traj[:, :, split_idx, :, :] - def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: - return acq_info[:, split_idx, ...] + def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + return cast(RotationOrTensor, acq_info[:, split_idx, ...]) # Rearrange other_split and k1 dimension rearrange_pattern_data = 'other coils other_split k2 k1 k0->(other other_split) coils k2 k1 k0' @@ -93,13 +96,14 @@ def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: ktraj = rearrange(split_data_traj(ktraj), rearrange_pattern_traj) # Create new header with correct shape - kheader = copy.deepcopy(self.header) + kheader = self.header.clone() # Update shape of acquisition info index - def reshape_acq_info(info: torch.Tensor): - return rearrange(split_acq_info(info), rearrange_pattern_acq_info) - - kheader.acq_info = modify_acq_info(reshape_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(split_acq_info(field), rearrange_pattern_acq_info) + if isinstance(field, Rotation | torch.Tensor) + else field + ) # Update other label limits and acquisition info setattr(kheader.encoding_limits, other_label, Limits(min=0, max=n_other - 1, center=0)) diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index b16fae37a..6cd18c2cc 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -1,21 +1,22 @@ import mrpro.utils.slice_profiles import mrpro.utils.typing +import mrpro.utils.unit_conversion from mrpro.utils.smap import smap from mrpro.utils.remove_repeat import remove_repeat from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop -from mrpro.utils.modify_acq_info import modify_acq_info from mrpro.utils.split_idx import split_idx from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view + __all__ = [ "broadcast_right", - "modify_acq_info", "reduce_view", "remove_repeat", "slice_profiles", "smap", "split_idx", "typing", + "unit_conversion", "unsqueeze_left", "unsqueeze_right", "zero_pad_or_crop" -] \ No newline at end of file +] diff --git a/src/mrpro/utils/modify_acq_info.py b/src/mrpro/utils/modify_acq_info.py deleted file mode 100644 index d535e53c5..000000000 --- a/src/mrpro/utils/modify_acq_info.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Modify AcqInfo.""" - -from __future__ import annotations - -import dataclasses -from collections.abc import Callable -from typing import TYPE_CHECKING - -import torch - -if TYPE_CHECKING: - from mrpro.data.AcqInfo import AcqInfo - - -def modify_acq_info(fun_modify: Callable, acq_info: AcqInfo) -> AcqInfo: - """Go through all fields of AcqInfo object and apply changes. - - Parameters - ---------- - fun_modify - Function which takes AcqInfo fields as input and returns modified AcqInfo field - acq_info - AcqInfo object - """ - # Apply function to all fields of acq_info - for field in dataclasses.fields(acq_info): - current = getattr(acq_info, field.name) - if isinstance(current, torch.Tensor): - setattr(acq_info, field.name, fun_modify(current)) - elif dataclasses.is_dataclass(current): - for subfield in dataclasses.fields(current): - subcurrent = getattr(current, subfield.name) - setattr(current, subfield.name, fun_modify(subcurrent)) - - return acq_info diff --git a/src/mrpro/utils/unit_conversion.py b/src/mrpro/utils/unit_conversion.py new file mode 100644 index 000000000..0115bed47 --- /dev/null +++ b/src/mrpro/utils/unit_conversion.py @@ -0,0 +1,94 @@ +"""Conversion between different units.""" + +from typing import TypeVar + +import numpy as np +import torch + +__all__ = [ + 'ms_to_s', + 's_to_ms', + 'mm_to_m', + 'm_to_mm', + 'deg_to_rad', + 'rad_to_deg', + 'lamor_frequency_to_magnetic_field', + 'magnetic_field_to_lamor_frequency', + 'GYROMAGNETIC_RATIO_PROTON', +] + +GYROMAGNETIC_RATIO_PROTON = 42.58 * 1e6 +r"""The gyromagnetic ratio :math:`\frac{\gamma}{2\pi}` of 1H in H20 in Hz/T""" + +# Conversion functions for units +T = TypeVar('T', float, torch.Tensor) + + +def ms_to_s(ms: T) -> T: + """Convert ms to s.""" + return ms / 1000 + + +def s_to_ms(s: T) -> T: + """Convert s to ms.""" + return s * 1000 + + +def mm_to_m(mm: T) -> T: + """Convert mm to m.""" + return mm / 1000 + + +def m_to_mm(m: T) -> T: + """Convert m to mm.""" + return m * 1000 + + +def deg_to_rad(deg: T) -> T: + """Convert degree to radians.""" + if isinstance(deg, torch.Tensor): + return torch.deg2rad(deg) + return deg / 180.0 * np.pi + + +def rad_to_deg(deg: T) -> T: + """Convert radians to degree.""" + if isinstance(deg, torch.Tensor): + return torch.rad2deg(deg) + return deg * 180.0 / np.pi + + +def lamor_frequency_to_magnetic_field(lamor_frequency: T, gyromagnetic_ratio: float = GYROMAGNETIC_RATIO_PROTON) -> T: + """Convert the Lamor frequency [Hz] to the magntic field strength [T]. + + Parameters + ---------- + lamor_frequency + Lamor frequency [Hz] + gyromagnetic_ratio + Gyromagnetic ratio [Hz/T], default: gyromagnetic ratio of 1H proton + + Returns + ------- + Magnetic field strength [T] + """ + return lamor_frequency / gyromagnetic_ratio + + +def magnetic_field_to_lamor_frequency( + magnetic_field_strength: T, gyromagnetic_ratio: float = GYROMAGNETIC_RATIO_PROTON +) -> T: + """Convert the magntic field strength [T] to Lamor frequency [Hz]. + + Parameters + ---------- + magnetic_field_strength + Strength of the magnetic field [T] + gyromagnetic_ratio + Gyromagnetic ratio [Hz/T], default: gyromagnetic ratio of 1H proton + + Returns + ------- + Lamor frequency [Hz] + """ + return magnetic_field_strength * gyromagnetic_ratio diff --git a/tests/conftest.py b/tests/conftest.py index e3f943462..899e8959c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,9 +45,9 @@ def generate_random_acquisition_properties(generator: RandomGenerator): 'encoding_space_ref': generator.uint16(), 'sample_time_us': generator.float32(), 'position': generator.float32_tuple(3), - 'read_dir': generator.float32_tuple(3), - 'phase_dir': generator.float32_tuple(3), - 'slice_dir': generator.float32_tuple(3), + 'read_dir': (1, 0, 0), # read, phase and slice have to form rotation + 'phase_dir': (0, 1, 0), + 'slice_dir': (0, 0, 1), 'patient_table_position': generator.float32_tuple(3), 'idx': ismrmrd.EncodingCounters(**idx_properties), 'user_int': generator.uint32_tuple(8), diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index 822c63045..b7ec1fa7d 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -2,12 +2,13 @@ import pytest import torch -from einops import rearrange, repeat +from einops import repeat from mrpro.data import KData, KTrajectory, SpatialDimension from mrpro.data.acq_filters import is_coil_calibration_acquisition +from mrpro.data.AcqInfo import rearrange_acq_info_fields from mrpro.data.traj_calculators.KTrajectoryCalculator import DummyTrajectory from mrpro.operators import FastFourierOp -from mrpro.utils import modify_acq_info, split_idx +from mrpro.utils import split_idx from tests.conftest import RandomGenerator, generate_random_data from tests.data import IsmrmrdRawTestData @@ -77,10 +78,11 @@ def consistently_shaped_kdata(request, random_kheader_shape): # Start with header kheader, n_other, n_coils, n_k2, n_k1, n_k0 = random_kheader_shape - def reshape_acq_data(data): - return rearrange(data, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1) - - kheader.acq_info = modify_acq_info(reshape_acq_data, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1 + ) + ) # Create kdata with consistent shape kdata = generate_random_data(RandomGenerator(request.param['seed']), (n_other, n_coils, n_k2, n_k1, n_k0)) @@ -162,7 +164,7 @@ def test_KData_kspace(ismrmrd_cart): assert relative_image_difference(reconstructed_img[0, 0, 0, ...], ismrmrd_cart.img_ref) <= 0.05 -@pytest.mark.parametrize(('field', 'value'), [('b0', 11.3), ('tr', torch.tensor([24.3]))]) +@pytest.mark.parametrize(('field', 'value'), [('lamor_frequency_proton', 42.88 * 1e6), ('tr', torch.tensor([24.3]))]) def test_KData_modify_header(ismrmrd_cart, field, value): """Overwrite some parameters in the header.""" parameter_dict = {field: value} @@ -469,3 +471,21 @@ def test_KData_remove_readout_os(monkeypatch, random_kheader): # testing functions such as numpy.testing.assert_almost_equal fails because there are few voxels with high # differences along the edges of the elliptic objects. assert relative_image_difference(torch.abs(img_recon), img_tensor[:, 0, ...]) <= 0.05 + + +def test_modify_acq_info(random_kheader_shape): + """Test the modification of the acquisition info.""" + # Create random header where AcqInfo fields are of shape [n_k1*n_k2] and reshape to [n_other, n_k2, n_k1] + kheader, n_other, _, n_k2, n_k1, _ = random_kheader_shape + + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1 + ) + ) + + # Verify shape + assert kheader.acq_info.center_sample.shape == (n_other, n_k2, n_k1, 1) + assert kheader.acq_info.idx.k1.shape == (n_other, n_k2, n_k1) + assert kheader.acq_info.orientation.shape == (n_other, n_k2, n_k1, 1) + assert kheader.acq_info.position.z.shape == (n_other, n_k2, n_k1, 1) diff --git a/tests/utils/test_modify_acq_info.py b/tests/utils/test_modify_acq_info.py deleted file mode 100644 index 451303d02..000000000 --- a/tests/utils/test_modify_acq_info.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Tests for modification of acquisition infos.""" - -from einops import rearrange -from mrpro.utils import modify_acq_info - - -def test_modify_acq_info(random_kheader_shape): - """Test the modification of the acquisition info.""" - # Create random header where AcqInfo fields are of shape [n_k1*n_k2] and reshape to [n_other, n_k2, n_k1] - kheader, n_other, _, n_k2, n_k1, _ = random_kheader_shape - - def reshape_acq_data(data): - return rearrange(data, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1) - - kheader.acq_info = modify_acq_info(reshape_acq_data, kheader.acq_info) - - # Verify shape - assert kheader.acq_info.center_sample.shape == (n_other, n_k2, n_k1, 1) diff --git a/tests/utils/test_unit_conversion.py b/tests/utils/test_unit_conversion.py new file mode 100644 index 000000000..a232a5366 --- /dev/null +++ b/tests/utils/test_unit_conversion.py @@ -0,0 +1,82 @@ +"""Tests of unit conversion.""" + +import numpy as np +import torch +from mrpro.utils.unit_conversion import ( + deg_to_rad, + lamor_frequency_to_magnetic_field, + m_to_mm, + magnetic_field_to_lamor_frequency, + mm_to_m, + ms_to_s, + rad_to_deg, + s_to_ms, +) + +from tests import RandomGenerator + + +def test_mm_to_m(): + """Verify mm to m conversion.""" + generator = RandomGenerator(seed=0) + mm_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(mm_to_m(mm_input), mm_input / 1000.0) + + +def test_m_to_mm(): + """Verify m to mm conversion.""" + generator = RandomGenerator(seed=0) + m_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(m_to_mm(m_input), m_input * 1000.0) + + +def test_ms_to_s(): + """Verify ms to s conversion.""" + generator = RandomGenerator(seed=0) + ms_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(ms_to_s(ms_input), ms_input / 1000.0) + + +def test_s_to_ms(): + """Verify s to ms conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(s_to_ms(s_input), s_input * 1000.0) + + +def test_rad_to_deg_tensor(): + """Verify radians to degree conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(rad_to_deg(s_input), torch.rad2deg(s_input)) + + +def test_deg_to_rad_tensor(): + """Verify degree to radians conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(deg_to_rad(s_input), torch.deg2rad(s_input)) + + +def test_rad_to_deg_float(): + """Verify radians to degree conversion.""" + assert rad_to_deg(np.pi / 2) == 90.0 + + +def test_deg_to_rad_float(): + """Verify degree to radians conversion.""" + assert deg_to_rad(180.0) == np.pi + + +def test_lamor_frequency_to_magnetic_field(): + """Verify conversion of lamor frequency to magnetic field.""" + proton_gyromagnetic_ratio = 42.58 * 1e6 + proton_lamor_frequency_at_3tesla = 127.74 * 1e6 + assert lamor_frequency_to_magnetic_field(proton_lamor_frequency_at_3tesla, proton_gyromagnetic_ratio) == 3.0 + + +def test_magnetic_field_to_lamor_frequency(): + """Verify conversion of magnetic field to lamor frequency.""" + proton_gyromagnetic_ratio = 42.58 * 1e6 + magnetic_field_strength = 3.0 + assert magnetic_field_to_lamor_frequency(magnetic_field_strength, proton_gyromagnetic_ratio) == 127.74 * 1e6 From 84b983cda06f3408c3b255866eb5c1087f7f2e48 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 12 Nov 2024 10:51:20 +0100 Subject: [PATCH 09/35] Release v0.241112 (#510) --- src/mrpro/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/VERSION b/src/mrpro/VERSION index 0f6ae6fb6..c60039027 100644 --- a/src/mrpro/VERSION +++ b/src/mrpro/VERSION @@ -1 +1 @@ -0.241029 +0.241112 From 762fcd777adca6427751a2d00015748d653416b4 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Tue, 12 Nov 2024 12:29:54 +0100 Subject: [PATCH 10/35] Select k-space data based on n_coils (#309) Co-authored-by: Felix F Zimmermann --- src/mrpro/data/_kdata/KData.py | 25 +++++++++++++++++- src/mrpro/data/acq_filters.py | 17 ++++++++++++ src/mrpro/utils/__init__.py | 1 + tests/data/_IsmrmrdRawTestData.py | 28 +++++++++++++------- tests/data/test_kdata.py | 44 ++++++++++++++++++++++++++++++- 5 files changed, 104 insertions(+), 11 deletions(-) diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/_kdata/KData.py index 57af617bc..d43fc49cd 100644 --- a/src/mrpro/data/_kdata/KData.py +++ b/src/mrpro/data/_kdata/KData.py @@ -17,7 +17,7 @@ from mrpro.data._kdata.KDataRemoveOsMixin import KDataRemoveOsMixin from mrpro.data._kdata.KDataSelectMixin import KDataSelectMixin from mrpro.data._kdata.KDataSplitMixin import KDataSplitMixin -from mrpro.data.acq_filters import is_image_acquisition +from mrpro.data.acq_filters import has_n_coils, is_image_acquisition from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits from mrpro.data.KHeader import KHeader @@ -110,6 +110,29 @@ def from_file( modification_time = datetime.datetime.fromtimestamp(mtime) acquisitions = [acq for acq in acquisitions if acquisition_filter_criterion(acq)] + + # we need the same number of receiver coils for all acquisitions + n_coils_available = {acq.data.shape[0] for acq in acquisitions} + if len(n_coils_available) > 1: + if ( + ismrmrd_header.acquisitionSystemInformation is not None + and ismrmrd_header.acquisitionSystemInformation.receiverChannels is not None + ): + n_coils = int(ismrmrd_header.acquisitionSystemInformation.receiverChannels) + else: + # most likely, highest number of elements are the coils used for imaging + n_coils = int(max(n_coils_available)) + + warnings.warn( + f'Acquisitions with different number {n_coils_available} of receiver coil elements detected.' + 'Data with {n_coils} receiver coil elements will be used.', + stacklevel=1, + ) + acquisitions = [acq for acq in acquisitions if has_n_coils(n_coils, acq)] + + if not acquisitions: + raise ValueError('No acquisitions meeting the given filter criteria were found.') + kdata = torch.stack([torch.as_tensor(acq.data, dtype=torch.complex64) for acq in acquisitions]) acqinfo = AcqInfo.from_ismrmrd_acquisitions(acquisitions) diff --git a/src/mrpro/data/acq_filters.py b/src/mrpro/data/acq_filters.py index d64c4d9a6..4723d3bba 100644 --- a/src/mrpro/data/acq_filters.py +++ b/src/mrpro/data/acq_filters.py @@ -61,3 +61,20 @@ def is_coil_calibration_acquisition(acquisition: ismrmrd.Acquisition) -> bool: """ coil_calibration_flag = AcqFlags.ACQ_IS_PARALLEL_CALIBRATION | AcqFlags.ACQ_IS_PARALLEL_CALIBRATION_AND_IMAGING return coil_calibration_flag.value & acquisition.flags + + +def has_n_coils(n_coils: int, acquisition: ismrmrd.Acquisition) -> bool: + """Test if acquisitions was obtained with a certain number of receiver coils. + + Parameters + ---------- + n_coils + number of receiver coils + acquisition + ISMRMRD acquisition + + Returns + ------- + True if the acquisition was obtained with n_coils receiver coils + """ + return acquisition.data.shape[0] == n_coils diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index 6cd18c2cc..80ef9d398 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -6,6 +6,7 @@ from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop from mrpro.utils.split_idx import split_idx from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view +import mrpro.utils.unit_conversion __all__ = [ "broadcast_right", diff --git a/tests/data/_IsmrmrdRawTestData.py b/tests/data/_IsmrmrdRawTestData.py index 59c42be15..181be56d6 100644 --- a/tests/data/_IsmrmrdRawTestData.py +++ b/tests/data/_IsmrmrdRawTestData.py @@ -67,6 +67,7 @@ def __init__( trajectory_type: Literal['cartesian', 'radial'] = 'cartesian', sampling_order: Literal['linear', 'low_high', 'high_low', 'random'] = 'linear', phantom: EllipsePhantom | None = None, + add_bodycoil_acquisitions: bool = False, n_separate_calibration_lines: int = 0, ): if not phantom: @@ -222,23 +223,32 @@ def __init__( acq.phase_dir[1] = 1.0 acq.slice_dir[2] = 1.0 - # Initialize an acquisition counter - counter = 0 + scan_counter = 0 # Write out a few noise scans for _ in range(32): noise = self.noise_level * torch.randn(self.n_coils, n_freq_encoding, dtype=torch.complex64) # here's where we would make the noise correlated - acq.scan_counter = counter + acq.scan_counter = scan_counter acq.clearAllFlags() acq.setFlag(ismrmrd.ACQ_IS_NOISE_MEASUREMENT) acq.data[:] = noise.numpy() dataset.append_acquisition(acq) - counter += 1 # increment the scan counter + scan_counter += 1 + + # Add acquisitions obtained with a 2-element body coil (e.g. used for adjustment scans) + if add_bodycoil_acquisitions: + acq.resize(n_freq_encoding, 2, trajectory_dimensions=2) + for _ in range(8): + acq.scan_counter = scan_counter + acq.clearAllFlags() + acq.data[:] = torch.randn(2, n_freq_encoding, dtype=torch.complex64) + dataset.append_acquisition(acq) + scan_counter += 1 + acq.resize(n_freq_encoding, self.n_coils, trajectory_dimensions=2) # Calibration lines if n_separate_calibration_lines > 0: - # we take calibration lines around the k-space center traj_ky_calibration, traj_kx_calibration, kpe_calibration = self._cartesian_trajectory( n_separate_calibration_lines, n_freq_encoding, @@ -253,7 +263,7 @@ def __init__( for pe_idx, pe_pos in enumerate(kpe_calibration): # Set some fields in the header - acq.scan_counter = counter + acq.scan_counter = scan_counter # kpe is in the range [-npe//2, npe//2), the ismrmrd kspace_encoding_step_1 is in the range [0, npe) kspace_encoding_step_1 = pe_pos + n_phase_encoding // 2 @@ -264,7 +274,7 @@ def __init__( # Set the data and append acq.data[:] = kspace_calibration[:, :, pe_idx].numpy() dataset.append_acquisition(acq) - counter += 1 + scan_counter += 1 # Loop over the repetitions, add noise and write to disk for rep in range(self.repetitions): @@ -275,7 +285,7 @@ def __init__( for pe_idx, pe_pos in enumerate(kpe[rep]): if not self.flag_invalid_reps or rep == 0 or pe_idx < len(kpe[rep]) // 2: # fewer lines for rep > 0 # Set some fields in the header - acq.scan_counter = counter + acq.scan_counter = scan_counter # kpe is in the range [-npe//2, npe//2), the ismrmrd kspace_encoding_step_1 is in the range [0, npe) kspace_encoding_step_1 = pe_pos + n_phase_encoding // 2 @@ -298,7 +308,7 @@ def __init__( # Set the data and append acq.data[:] = kspace_with_noise[:, :, pe_idx].numpy() dataset.append_acquisition(acq) - counter += 1 + scan_counter += 1 # Clean up dataset.close() diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index b7ec1fa7d..ab4f5aabb 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -4,7 +4,7 @@ import torch from einops import repeat from mrpro.data import KData, KTrajectory, SpatialDimension -from mrpro.data.acq_filters import is_coil_calibration_acquisition +from mrpro.data.acq_filters import has_n_coils, is_coil_calibration_acquisition, is_image_acquisition from mrpro.data.AcqInfo import rearrange_acq_info_fields from mrpro.data.traj_calculators.KTrajectoryCalculator import DummyTrajectory from mrpro.operators import FastFourierOp @@ -29,6 +29,20 @@ def ismrmrd_cart(ellipse_phantom, tmp_path_factory): return ismrmrd_kdata +@pytest.fixture(scope='session') +def ismrmrd_cart_bodycoil_and_surface_coil(ellipse_phantom, tmp_path_factory): + """Fully sampled cartesian data set with bodycoil and surface coil data.""" + ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' + ismrmrd_kdata = IsmrmrdRawTestData( + filename=ismrmrd_filename, + noise_level=0.0, + repetitions=3, + phantom=ellipse_phantom.phantom, + add_bodycoil_acquisitions=True, + ) + return ismrmrd_kdata + + @pytest.fixture(scope='session') def ismrmrd_cart_with_calibration_lines(ellipse_phantom, tmp_path_factory): """Undersampled Cartesian data set with calibration lines.""" @@ -126,6 +140,34 @@ def test_KData_raise_wrong_trajectory_shape(ismrmrd_cart): _ = KData.from_file(ismrmrd_cart.filename, trajectory) +def test_KData_raise_warning_for_bodycoil(ismrmrd_cart_bodycoil_and_surface_coil): + """Mix of bodycoil and surface coil acquisitions leads to warning.""" + with pytest.raises(UserWarning, match='Acquisitions with different number'): + _ = KData.from_file(ismrmrd_cart_bodycoil_and_surface_coil.filename, DummyTrajectory()) + + +@pytest.mark.filterwarnings('ignore:Acquisitions with different number:UserWarning') +def test_KData_select_bodycoil_via_filter(ismrmrd_cart_bodycoil_and_surface_coil): + """Bodycoil can be selected via a custom acquisition filter.""" + # This is the recommended way of selecting the body coil (i.e. 2 receiver elements) + kdata = KData.from_file( + ismrmrd_cart_bodycoil_and_surface_coil.filename, + DummyTrajectory(), + acquisition_filter_criterion=lambda acq: has_n_coils(2, acq) and is_image_acquisition(acq), + ) + assert kdata.data.shape[-4] == 2 + + +def test_KData_raise_wrong_coil_number(ismrmrd_cart): + """Wrong number of coils leads to empty acquisitions.""" + with pytest.raises(ValueError, match='No acquisitions meeting the given filter criteria were found'): + _ = KData.from_file( + ismrmrd_cart.filename, + DummyTrajectory(), + acquisition_filter_criterion=lambda acq: has_n_coils(2, acq) and is_image_acquisition(acq), + ) + + def test_KData_from_file_diff_nky_for_rep(ismrmrd_cart_invalid_reps): """Multiple repetitions with different number of phase encoding lines.""" with pytest.warns(UserWarning, match=r'different number'): From 455679547c70e55aff86c40c0b1ee1ab98742e90 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Tue, 12 Nov 2024 13:01:04 +0100 Subject: [PATCH 11/35] Fix formatting in warning string (#514) --- src/mrpro/data/_kdata/KData.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/_kdata/KData.py index d43fc49cd..aaf430497 100644 --- a/src/mrpro/data/_kdata/KData.py +++ b/src/mrpro/data/_kdata/KData.py @@ -124,8 +124,8 @@ def from_file( n_coils = int(max(n_coils_available)) warnings.warn( - f'Acquisitions with different number {n_coils_available} of receiver coil elements detected.' - 'Data with {n_coils} receiver coil elements will be used.', + f'Acquisitions with different number {n_coils_available} of receiver coil elements detected. ' + f'Data with {n_coils} receiver coil elements will be used.', stacklevel=1, ) acquisitions = [acq for acq in acquisitions if has_n_coils(n_coils, acq)] From a8214130948447f27e1a1c540b366f865acfae5d Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 12 Nov 2024 13:05:32 +0100 Subject: [PATCH 12/35] Add human readable ids to test (#485) --- tests/algorithms/test_cg.py | 3 +- tests/conftest.py | 300 +++++++++--------- tests/data/test_rotation.py | 8 +- tests/data/test_trajectory.py | 20 +- tests/operators/functionals/conftest.py | 6 +- tests/operators/models/conftest.py | 20 +- tests/operators/test_cartesian_sampling_op.py | 16 +- tests/operators/test_fourier_op.py | 36 ++- tests/operators/test_rearrangeop.py | 1 + 9 files changed, 217 insertions(+), 193 deletions(-) diff --git a/tests/algorithms/test_cg.py b/tests/algorithms/test_cg.py index 8a4434e2a..4abda2a98 100644 --- a/tests/algorithms/test_cg.py +++ b/tests/algorithms/test_cg.py @@ -16,7 +16,8 @@ (1, 32, False), (4, 32, True), (4, 32, False), - ] + ], + ids=['complex_single', 'real_single', 'complex_batch', 'real_batch'], ) def system(request): """Generate data for creating a system Hx=b with linear and self-adjoint diff --git a/tests/conftest.py b/tests/conftest.py index 899e8959c..1fb7fb95f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -233,16 +233,16 @@ def create_uniform_traj(nk, k_shape): return k -def create_traj(k_shape, nkx, nky, nkz, sx, sy, sz): +def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): """Create trajectory with random entries.""" random_generator = RandomGenerator(seed=0) k_list = [] - for spacing, nk in zip([sz, sy, sx], [nkz, nky, nkx], strict=True): - if spacing == 'nuf': + for spacing, nk in zip([type_kz, type_ky, type_kx], [nkz, nky, nkx], strict=True): + if spacing == 'non-uniform': k = random_generator.float32_tensor(size=nk) - elif spacing == 'uf': + elif spacing == 'uniform': k = create_uniform_traj(nk, k_shape=k_shape) - elif spacing == 'z': + elif spacing == 'zero': k = torch.zeros(nk) k_list.append(k) trajectory = KTrajectory(k_list[0], k_list[1], k_list[2], repeat_detection_tolerance=None) @@ -250,161 +250,163 @@ def create_traj(k_shape, nkx, nky, nkz, sx, sy, sz): COMMON_MR_TRAJECTORIES = pytest.mark.parametrize( - ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'sx', 'sy', 'sz', 's0', 's1', 's2'), + ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz', 'type_k0', 'type_k1', 'type_k2'), [ - # (0) 2d cart mri with 1 coil, no oversampling - ( - (1, 1, 1, 96, 128), # img shape - (1, 1, 1, 96, 128), # k shape - (1, 1, 1, 128), # kx - (1, 1, 96, 1), # ky - (1, 1, 1, 1), # kz - 'uf', # kx is uniform - 'uf', # ky is uniform - 'z', # zero so no Fourier transform is performed along that dimension - 'uf', # k0 is uniform - 'uf', # k1 is uniform - 'z', # k2 is singleton + ( # (0) 2d Cartesian single coil, no oversampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 96, 128), # k_shape + (1, 1, 1, 128), # nkx + (1, 1, 96, 1), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (1) 2d cart mri with 1 coil, with oversampling - ( - (1, 1, 1, 96, 128), - (1, 1, 1, 128, 192), - (1, 1, 1, 192), - (1, 1, 128, 1), - (1, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'uf', - 'z', + ( # (1) 2d Cartesian single coil, with oversampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 128, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 128, 1), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (2) 2d non-Cartesian mri with 2 coils - ( - (1, 2, 1, 96, 128), - (1, 2, 1, 16, 192), - (1, 1, 16, 192), - (1, 1, 16, 192), - (1, 1, 1, 1), - 'nuf', # kx is non-uniform - 'nuf', - 'z', - 'nuf', - 'nuf', - 'z', + ( # (2) 2d non-Cartesian mri with 2 coils + (1, 2, 1, 96, 128), # im_shape + (1, 2, 1, 16, 192), # k_shape + (1, 1, 16, 192), # nkx + (1, 1, 16, 192), # nky + (1, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'zero', # type_k2 ), - # (3) 2d cart mri with irregular sampling - ( - (1, 1, 1, 96, 128), - (1, 1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'z', - 'z', + ( # (3) 2d Cartesian with irregular sampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 1, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 1, 192), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'zero', # type_k1 + 'zero', # type_k2 ), - # (4) 2d single shot spiral - ( - (1, 2, 1, 96, 128), - (1, 1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 1), - 'nuf', - 'nuf', - 'z', - 'nuf', - 'z', - 'z', + ( # (4) 2d single shot spiral + (1, 2, 1, 96, 128), # im_shape + (1, 1, 1, 1, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 1, 192), # nky + (1, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'zero', # type_k1 + 'zero', # type_k2 ), - # (5) 3d nuFFT mri, 4 coils, 2 other - ( - (2, 4, 16, 32, 64), - (2, 4, 16, 32, 64), - (2, 16, 32, 64), - (2, 16, 32, 64), - (2, 16, 32, 64), - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', + ( # (5) 3d non-uniform, 4 coils, 2 other + (2, 4, 16, 32, 64), # im_shape + (2, 4, 16, 32, 64), # k_shape + (2, 16, 32, 64), # nkx + (2, 16, 32, 64), # nky + (2, 16, 32, 64), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (6) 2d nuFFT cine mri with 8 cardiac phases, 5 coils - ( - (8, 5, 1, 64, 64), - (8, 5, 1, 18, 128), - (8, 1, 18, 128), - (8, 1, 18, 128), - (8, 1, 1, 1), - 'nuf', - 'nuf', - 'z', - 'nuf', - 'nuf', - 'z', + ( # (6) 2d non-uniform cine with 8 cardiac phases, 5 coils + (8, 5, 1, 64, 64), # im_shape + (8, 5, 1, 18, 128), # k_shape + (8, 1, 18, 128), # nkx + (8, 1, 18, 128), # nky + (8, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'zero', # type_k2 ), - # (7) 2d cart cine mri with 9 cardiac phases, 6 coils - ( - (9, 6, 1, 96, 128), - (9, 6, 1, 128, 192), - (9, 1, 1, 192), - (9, 1, 128, 1), - (9, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'uf', - 'z', + ( # (7) 2d cartesian cine with 9 cardiac phases, 6 coils + (9, 6, 1, 96, 128), # im_shape + (9, 6, 1, 128, 192), # k_shape + (9, 1, 1, 192), # nkx + (9, 1, 128, 1), # nky + (9, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (8) radial phase encoding (RPE), 8 coils, with oversampling in both FFT and nuFFT directions - ( - (2, 8, 64, 32, 48), - (2, 8, 8, 64, 96), - (2, 1, 1, 96), - (2, 8, 64, 1), - (2, 8, 64, 1), - 'uf', - 'nuf', - 'nuf', - 'uf', - 'nuf', - 'nuf', + ( # (8) radial phase encoding (RPE), 8 coils, with oversampling in both FFT and non-uniform directions + (2, 8, 64, 32, 48), # im_shape + (2, 8, 8, 64, 96), # k_shape + (2, 1, 1, 96), # nkx + (2, 8, 64, 1), # nky + (2, 8, 64, 1), # nkz + 'uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (9) radial phase encoding (RPE) , 8 coils with non-Cartesian sampling along readout - ( - (2, 8, 64, 32, 48), - (2, 8, 8, 64, 96), - (2, 1, 1, 96), - (2, 8, 64, 1), - (2, 8, 64, 1), - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', + ( # (9) radial phase encoding (RPE), 8 coils with non-Cartesian sampling along readout + (2, 8, 64, 32, 48), # im_shape + (2, 8, 8, 64, 96), # k_shape + (2, 1, 1, 96), # nkx + (2, 8, 64, 1), # nky + (2, 8, 64, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (10) stack of stars, 5 other, 3 coil, oversampling in both FFT and nuFFT directions - ( - (5, 3, 48, 16, 32), - (5, 3, 96, 18, 64), - (5, 1, 18, 64), - (5, 1, 18, 64), - (5, 96, 1, 1), - 'nuf', - 'nuf', - 'uf', - 'nuf', - 'nuf', - 'uf', + ( # (10) stack of stars, 5 other, 3 coil, oversampling in both FFT and non-uniform directions + (5, 3, 48, 16, 32), # im_shape + (5, 3, 96, 18, 64), # k_shape + (5, 1, 18, 64), # nkx + (5, 1, 18, 64), # nky + (5, 96, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'uniform', # type_k2 ), ], + ids=[ + '2d_cartesian_1_coil_no_oversampling', + '2d_cartesian_1_coil_with_oversampling', + '2d_non_cartesian_mri_2_coils', + '2d_cartesian_irregular_sampling', + '2d_single_shot_spiral', + '3d_nonuniform_4_coils_2_other', + '2d_nnonuniform_cine_mri_8_cardiac_phases_5_coils', + '2d_cartesian_cine_9_cardiac_phases_6_coils', + 'radial_phase_encoding_8_coils_with_oversampling', + 'radial_phase_encoding_8_coils_non_cartesian_sampling', + 'stack_of_stars_5_other_3_coil_with_oversampling', + ], ) diff --git a/tests/data/test_rotation.py b/tests/data/test_rotation.py index 6b4cbb52c..035a2d6c6 100644 --- a/tests/data/test_rotation.py +++ b/tests/data/test_rotation.py @@ -535,7 +535,7 @@ def _test_stats(error: torch.Tensor, mean_max: float, rms_max: float) -> None: assert torch.all(rms < rms_max) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_asymmetric_axes(seq_tuple, intrinsic): rnd = RandomGenerator(0) @@ -555,7 +555,7 @@ def test_as_euler_asymmetric_axes(seq_tuple, intrinsic): _test_stats(angles_quat - angles, 1e-15, 1e-14) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_symmetric_axes(seq_tuple, intrinsic): rnd = RandomGenerator(0) @@ -576,7 +576,7 @@ def test_as_euler_symmetric_axes(seq_tuple, intrinsic): _test_stats(angles_quat - angles, 1e-16, 1e-14) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_degenerate_asymmetric_axes(seq_tuple, intrinsic): # Since we cannot check for angle equality, we check for rotation matrix @@ -598,7 +598,7 @@ def test_as_euler_degenerate_asymmetric_axes(seq_tuple, intrinsic): torch.testing.assert_close(mat_expected, mat_estimated) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_degenerate_symmetric_axes(seq_tuple, intrinsic): # Since we cannot check for angle equality, we check for rotation matrix diff --git a/tests/data/test_trajectory.py b/tests/data/test_trajectory.py index 1baf4340b..1061a93be 100644 --- a/tests/data/test_trajectory.py +++ b/tests/data/test_trajectory.py @@ -147,16 +147,16 @@ def test_trajectory_cpu(cartesian_grid): @COMMON_MR_TRAJECTORIES -def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2): """Test identification of traj types.""" # Generate random k-space trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Find out the type of the kz, ky and kz dimensions - single_value_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'z'] - on_grid_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'uf'] - not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'nuf'] + single_value_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'z'] + on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'uf'] + not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'nuf'] # check dimensions which are of shape 1 and do not need any transform assert all(trajectory.type_along_kzyx[dim] & TrajType.SINGLEVALUE for dim in single_value_dims) @@ -171,16 +171,16 @@ def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, @COMMON_MR_TRAJECTORIES -def test_ktype_along_k210(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_ktype_along_k210(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2): """Test identification of traj types.""" # Generate random k-space trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Find out the type of the k2, k1 and k0 dimensions - single_value_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'z'] - on_grid_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'uf'] - not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'nuf'] + single_value_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'z'] + on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'uf'] + not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'nuf'] # check dimensions which are of shape 1 and do not need any transform assert all(trajectory.type_along_k210[dim] & TrajType.SINGLEVALUE for dim in single_value_dims) diff --git a/tests/operators/functionals/conftest.py b/tests/operators/functionals/conftest.py index 9d2bfd8c9..ae8ffc20b 100644 --- a/tests/operators/functionals/conftest.py +++ b/tests/operators/functionals/conftest.py @@ -47,12 +47,12 @@ def result_dtype(self): def functional_test_cases(func: Callable[[FunctionalTestCase], None]) -> Callable[..., None]: """Decorator combining multiple parameterizations for test cases for all proximable functionals.""" - @pytest.mark.parametrize('shape', [[1, 2, 3]]) + @pytest.mark.parametrize('shape', [[1, 2, 3]], ids=['shape=[1,2,3]']) @pytest.mark.parametrize('dtype_name', ['float32', 'complex64']) @pytest.mark.parametrize('weight', ['scalar_weight', 'tensor_weight', 'complex_weight']) @pytest.mark.parametrize('target', ['no_target', 'random_target']) - @pytest.mark.parametrize('dim', [None]) - @pytest.mark.parametrize('divide_by_n', [True, False]) + @pytest.mark.parametrize('dim', [None], ids=['dim=None']) + @pytest.mark.parametrize('divide_by_n', [True, False], ids=['mean', 'sum']) @pytest.mark.parametrize('functional', PROXIMABLE_FUNCTIONALS) def wrapper( functional: type[ElementaryProximableFunctional], diff --git a/tests/operators/models/conftest.py b/tests/operators/models/conftest.py index 570fa2f1f..75fceacd2 100644 --- a/tests/operators/models/conftest.py +++ b/tests/operators/models/conftest.py @@ -8,7 +8,7 @@ SHAPE_VARIATIONS_SIGNAL_MODELS = pytest.mark.parametrize( ('parameter_shape', 'contrast_dim_shape', 'signal_shape'), [ - ((1, 1, 10, 20, 30), (5,), (5, 1, 1, 10, 20, 30)), # single map with different inversion times + ((1, 1, 10, 20, 30), (5,), (5, 1, 1, 10, 20, 30)), # single map with different contrast times ((1, 1, 10, 20, 30), (5, 1), (5, 1, 1, 10, 20, 30)), ((4, 1, 1, 10, 20, 30), (5, 1), (5, 4, 1, 1, 10, 20, 30)), # multiple maps along additional batch dimension ((4, 1, 1, 10, 20, 30), (5,), (5, 4, 1, 1, 10, 20, 30)), @@ -25,6 +25,24 @@ ((1,), (5,), (5, 1)), # single voxel ((4, 3, 1), (5, 4, 3), (5, 4, 3, 1)), ], + ids=[ + 'single_map_diff_contrast_times', + 'single_map_diff_contrast_times_2', + 'multiple_maps_additional_batch_dim', + 'multiple_maps_additional_batch_dim_2', + 'multiple_maps_additional_batch_dim_3', + 'multiple_maps_other_dim', + 'multiple_maps_other_dim_2', + 'multiple_maps_other_dim_3', + 'multiple_maps_other_and_batch_dim', + 'multiple_maps_other_and_batch_dim_2', + 'multiple_maps_other_and_batch_dim_3', + 'multiple_maps_other_and_batch_dim_4', + 'multiple_maps_other_and_batch_dim_5', + 'different_value_each_voxel', + 'single_voxel', + 'multiple_voxels', + ], ) diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 7caa13e7c..28c6e8860 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -17,10 +17,10 @@ def test_cart_sampling_op_data_match(): nkx = (1, 1, 1, 60) nky = (1, 1, 40, 1) nkz = (1, 20, 1, 1) - sx = 'uf' - sy = 'uf' - sz = 'uf' - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + type_kx = 'uniform' + type_ky = 'uniform' + type_kz = 'uniform' + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Create matching data random_generator = RandomGenerator(seed=0) @@ -73,10 +73,10 @@ def test_cart_sampling_op_fwd_adj(sampling): nkx = (2, 1, 1, 60) nky = (2, 1, 40, 1) nkz = (2, 20, 1, 1) - sx = 'uf' - sy = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' - sz = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' - trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz).as_tensor() + type_kx = 'uniform' + type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz).as_tensor() # Subsample data and trajectory match sampling: diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index 89a4bbc11..6f8c377cf 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -9,22 +9,24 @@ from tests.helper import dotproduct_adjointness_test -def create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz): +def create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): random_generator = RandomGenerator(seed=0) # generate random image img = random_generator.complex64_tensor(size=im_shape) # create random trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) return img, trajectory @COMMON_MR_TRAJECTORIES -def test_fourier_fwd_adj_property(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_fourier_op_fwd_adj_property( + im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2 +): """Test adjoint property of Fourier operator.""" # generate random images and k-space trajectories - img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz) + img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) @@ -42,26 +44,26 @@ def test_fourier_fwd_adj_property(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, @pytest.mark.parametrize( - ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'sx', 'sy', 'sz'), + ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz'), # parameter names [ - # Cartesian FFT dimensions are not aligned with corresponding k2, k1, k0 dimensions - ( - (5, 3, 48, 16, 32), - (5, 3, 96, 18, 64), - (5, 1, 18, 64), - (5, 96, 1, 1), # Cartesian ky dimension defined along k2 rather than k1 - (5, 1, 18, 64), - 'nuf', - 'uf', - 'nuf', + ( # Cartesian FFT dimensions are not aligned with corresponding k2, k1, k0 dimensions + (5, 3, 48, 16, 32), # im_shape + (5, 3, 96, 18, 64), # k_shape + (5, 1, 18, 64), # nkx + (5, 96, 1, 1), # nky - Cartesian ky dimension defined along k2 rather than k1 + (5, 1, 18, 64), # nkz + 'non-uniform', # type_kx + 'uniform', # type_ky + 'non-uniform', # type_kz ), ], + ids=['cartesian_fft_dims_not_aligned_with_k2_k1_k0_dims'], ) -def test_fourier_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz): +def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): """Test trajectory not supported by Fourier operator.""" # generate random images and k-space trajectories - img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz) + img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) diff --git a/tests/operators/test_rearrangeop.py b/tests/operators/test_rearrangeop.py index 3db888897..054c99402 100644 --- a/tests/operators/test_rearrangeop.py +++ b/tests/operators/test_rearrangeop.py @@ -15,6 +15,7 @@ ((2, 2, 4), '... a b->... (a b)', (2, 8), {'b': 4}), # flatten ((2), '... (a b) -> ... a b', (2, 1), {'b': 1}), # unflatten ], + ids=['swap_axes', 'flatten', 'unflatten'], ) def test_einsum_op(input_shape, rule, output_shape, additional_info, dtype): """Test adjointness and shape.""" From 72c18070d45bbad969bf2244b51163b10f5b627a Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Tue, 12 Nov 2024 15:31:57 +0100 Subject: [PATCH 13/35] Add CartesianSamplingOp to FourierOp (#482) Co-authored-by: Felix F Zimmermann --- src/mrpro/operators/FourierOp.py | 41 +++++++++++++++++++----------- tests/conftest.py | 14 ++++++++++ tests/data/test_kdata.py | 13 ---------- tests/operators/test_fourier_op.py | 29 +++++++++++++++++++-- 4 files changed, 67 insertions(+), 30 deletions(-) diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index f4254d8d2..c57d1fbfd 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -11,6 +11,7 @@ from mrpro.data.enums import TrajType from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.operators.CartesianSamplingOp import CartesianSamplingOp from mrpro.operators.FastFourierOp import FastFourierOp from mrpro.operators.LinearOperator import LinearOperator @@ -67,12 +68,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): self._nufft_dims.append(dim) if self._fft_dims: - self._fast_fourier_op = FastFourierOp( + self._fast_fourier_op: FastFourierOp | None = FastFourierOp( dim=tuple(self._fft_dims), recon_matrix=get_spatial_dims(recon_matrix, self._fft_dims), encoding_matrix=get_spatial_dims(encoding_matrix, self._fft_dims), ) - + self._cart_sampling_op: CartesianSamplingOp | None = CartesianSamplingOp( + encoding_matrix=encoding_matrix, traj=traj + ) + else: + self._fast_fourier_op = None + self._cart_sampling_op = None # Find dimensions which require NUFFT if self._nufft_dims: fft_dims_k210 = [ @@ -102,20 +108,23 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega] self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction - self._fwd_nufft_op = KbNufft( + self._fwd_nufft_op: KbNufftAdjoint | None = KbNufft( im_size=self._nufft_im_size, grid_size=grid_size, numpoints=nufft_numpoints, kbwidth=nufft_kbwidth, ) - self._adj_nufft_op = KbNufftAdjoint( + self._adj_nufft_op: KbNufftAdjoint | None = KbNufftAdjoint( im_size=self._nufft_im_size, grid_size=grid_size, numpoints=nufft_numpoints, kbwidth=nufft_kbwidth, ) - - self._kshape = traj.broadcasted_shape + else: + self._omega: torch.Tensor | None = None + self._fwd_nufft_op = None + self._adj_nufft_op = None + self._kshape = traj.broadcasted_shape @classmethod def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> Self: @@ -146,11 +155,8 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil k-space data with shape: (... coils k2 k1 k0) """ - if len(self._fft_dims): - # FFT - (x,) = self._fast_fourier_op(x) - - if self._nufft_dims: + if self._fwd_nufft_op is not None and self._omega is not None: + # NUFFT Type 2 # we need to move the nufft-dimensions to the end and flatten all other dimensions # so the new shape will be (... non_nufft_dims) coils nufft_dims # we could move the permute to __init__ but then we still would need to prepend if len(other)>1 @@ -163,7 +169,6 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: x = x.flatten(end_dim=-len(keep_dims) - 1) # omega should be (... non_nufft_dims) n_nufft_dims (nufft_dims) - # TODO: consider moving the broadcast along fft dimensions to __init__ (independent of x shape). omega = self._omega.permute(*permute) omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :]) omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1) @@ -173,6 +178,11 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: shape_nufft_dims = [self._kshape[i] for i in self._nufft_dims] x = x.reshape(*permuted_x_shape[: -len(keep_dims)], -1, *shape_nufft_dims) # -1 is coils x = x.permute(*unpermute) + + if self._fast_fourier_op is not None and self._cart_sampling_op is not None: + # FFT + (x,) = self._cart_sampling_op(self._fast_fourier_op(x)[0]) + return (x,) def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: @@ -187,11 +197,12 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil image data with shape: (... coils z y x) """ - if self._fft_dims: + if self._fast_fourier_op is not None and self._cart_sampling_op is not None: # IFFT - (x,) = self._fast_fourier_op.adjoint(x) + (x,) = self._fast_fourier_op.adjoint(self._cart_sampling_op.adjoint(x)[0]) - if self._nufft_dims: + if self._adj_nufft_op is not None and self._omega is not None: + # NUFFT Type 1 # we need to move the nufft-dimensions to the end, flatten them and flatten all other dimensions # so the new shape will be (... non_nufft_dims) coils (nufft_dims) keep_dims = [-4, *self._nufft_dims] # -4 is coil diff --git a/tests/conftest.py b/tests/conftest.py index 1fb7fb95f..3bd8946f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from xsdata.models.datatype import XmlDate, XmlTime from tests import RandomGenerator +from tests.data import IsmrmrdRawTestData from tests.phantoms import EllipsePhantomTestData @@ -249,6 +250,19 @@ def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): return trajectory +@pytest.fixture(scope='session') +def ismrmrd_cart(ellipse_phantom, tmp_path_factory): + """Fully sampled cartesian data set.""" + ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' + ismrmrd_kdata = IsmrmrdRawTestData( + filename=ismrmrd_filename, + noise_level=0.0, + repetitions=3, + phantom=ellipse_phantom.phantom, + ) + return ismrmrd_kdata + + COMMON_MR_TRAJECTORIES = pytest.mark.parametrize( ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz', 'type_k0', 'type_k1', 'type_k2'), [ diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index ab4f5aabb..d5cfa0f0c 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -16,19 +16,6 @@ from tests.phantoms import EllipsePhantomTestData -@pytest.fixture(scope='session') -def ismrmrd_cart(ellipse_phantom, tmp_path_factory): - """Fully sampled cartesian data set.""" - ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' - ismrmrd_kdata = IsmrmrdRawTestData( - filename=ismrmrd_filename, - noise_level=0.0, - repetitions=3, - phantom=ellipse_phantom.phantom, - ) - return ismrmrd_kdata - - @pytest.fixture(scope='session') def ismrmrd_cart_bodycoil_and_surface_coil(ellipse_phantom, tmp_path_factory): """Fully sampled cartesian data set with bodycoil and surface coil data.""" diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index 6f8c377cf..f48a24260 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -1,7 +1,9 @@ """Tests for Fourier operator.""" import pytest -from mrpro.data import SpatialDimension +import torch +from mrpro.data import KData, KTrajectory, SpatialDimension +from mrpro.data.traj_calculators import KTrajectoryCartesian from mrpro.operators import FourierOp from tests import RandomGenerator @@ -30,7 +32,11 @@ def test_fourier_op_fwd_adj_property( # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) - encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) # apply forward operator @@ -70,3 +76,22 @@ def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) with pytest.raises(NotImplementedError, match='Cartesian FFT dims need to be aligned'): FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) + + +def test_fourier_op_cartesian_sorting(ismrmrd_cart): + """Verify correct sorting of Cartesian k-space data before FFT.""" + kdata = KData.from_file(ismrmrd_cart.filename, KTrajectoryCartesian()) + ff_op = FourierOp.from_kdata(kdata) + (img,) = ff_op.adjoint(kdata.data) + + # shuffle the kspace points along k0 + permutation_index = torch.randperm(kdata.data.shape[-1]) + kdata_unsorted = KData( + header=kdata.header, + data=kdata.data[..., permutation_index], + traj=KTrajectory.from_tensor(kdata.traj.as_tensor()[..., permutation_index]), + ) + ff_op_unsorted = FourierOp.from_kdata(kdata_unsorted) + (img_unsorted,) = ff_op_unsorted.adjoint(kdata_unsorted.data) + + torch.testing.assert_close(img, img_unsorted) From 202d395a2c10f2c3db7c491851988fd26dca14a5 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Tue, 12 Nov 2024 17:17:18 +0100 Subject: [PATCH 14/35] Exclude data outside of encoding_matrix (#234) Co-authored-by: Felix Zimmermann --- src/mrpro/operators/CartesianSamplingOp.py | 98 ++++++++++++++++--- tests/conftest.py | 2 +- tests/operators/test_cartesian_sampling_op.py | 25 +++++ tests/operators/test_fourier_op.py | 6 +- 4 files changed, 115 insertions(+), 16 deletions(-) diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index 7a51924b1..07f8aba65 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -1,5 +1,7 @@ """Cartesian Sampling Operator.""" +import warnings + import torch from einops import rearrange, repeat @@ -7,6 +9,7 @@ from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension from mrpro.operators.LinearOperator import LinearOperator +from mrpro.utils.reshape import unsqueeze_left class CartesianSamplingOp(LinearOperator): @@ -64,10 +67,35 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) -> # 1D indices into a flattened tensor. kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)') + + # check that all points are inside the encoding matrix + inside_encoding_matrix = ( + ((kx_idx >= 0) & (kx_idx < sorted_grid_shape.x)) + & ((ky_idx >= 0) & (ky_idx < sorted_grid_shape.y)) + & ((kz_idx >= 0) & (kz_idx < sorted_grid_shape.z)) + ) + if not torch.all(inside_encoding_matrix): + warnings.warn( + 'K-space points lie outside of the encoding_matrix and will be ignored.' + ' Increase the encoding_matrix to include these points.', + stacklevel=2, + ) + + inside_encoding_matrix = rearrange(inside_encoding_matrix, '... kz ky kx -> ... 1 (kz ky kx)') + inside_encoding_matrix_idx = inside_encoding_matrix.nonzero(as_tuple=True)[-1] + inside_encoding_matrix_idx = torch.reshape(inside_encoding_matrix_idx, (*kidx.shape[:-1], -1)) + self.register_buffer('_inside_encoding_matrix_idx', inside_encoding_matrix_idx) + kidx = torch.take_along_dim(kidx, inside_encoding_matrix_idx, dim=-1) + else: + self._inside_encoding_matrix_idx: torch.Tensor | None = None + self.register_buffer('_fft_idx', kidx) + # we can skip the indexing if the data is already sorted self._needs_indexing = ( - not torch.all(torch.diff(kidx) == 1) or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx + not torch.all(torch.diff(kidx) == 1) + or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx + or self._inside_encoding_matrix_idx is not None ) self._trajectory_shape = traj.broadcasted_shape @@ -93,8 +121,21 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: return (x,) x_kflat = rearrange(x, '... coil k2_enc k1_enc k0_enc -> ... coil (k2_enc k1_enc k0_enc)') - # take_along_dim does broadcast, so no need for extending here - x_indexed = torch.take_along_dim(x_kflat, self._fft_idx, dim=-1) + # take_along_dim broadcasts, but needs the same number of dimensions + idx = unsqueeze_left(self._fft_idx, x_kflat.ndim - self._fft_idx.ndim) + x_inside_encoding_matrix = torch.take_along_dim(x_kflat, idx, dim=-1) + + if self._inside_encoding_matrix_idx is None: + # all trajectory points are inside the encoding matrix + x_indexed = x_inside_encoding_matrix + else: + # we need to add zeros + x_indexed = self._broadcast_and_scatter_along_last_dim( + x_inside_encoding_matrix, + self._trajectory_shape[-1] * self._trajectory_shape[-2] * self._trajectory_shape[-3], + self._inside_encoding_matrix_idx, + ) + # reshape to (... other coil, k2, k1, k0) x_reshaped = x_indexed.reshape(x.shape[:-3] + self._trajectory_shape[-3:]) @@ -120,18 +161,13 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: y_kflat = rearrange(y, '... coil k2 k1 k0 -> ... coil (k2 k1 k0)') - # scatter does not broadcast, so we need to manually broadcast the indices - broadcast_shape = torch.broadcast_shapes(self._fft_idx.shape[:-1], y_kflat.shape[:-1]) - idx_expanded = torch.broadcast_to(self._fft_idx, (*broadcast_shape, self._fft_idx.shape[-1])) + if self._inside_encoding_matrix_idx is not None: + idx = unsqueeze_left(self._inside_encoding_matrix_idx, y_kflat.ndim - self._inside_encoding_matrix_idx.ndim) + y_kflat = torch.take_along_dim(y_kflat, idx, dim=-1) - # although scatter_ is inplace, this will not cause issues with autograd, as self - # is always constant zero and gradients w.r.t. src work as expected. - y_scattered = torch.zeros( - *broadcast_shape, - self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, - dtype=y.dtype, - device=y.device, - ).scatter_(dim=-1, index=idx_expanded, src=y_kflat) + y_scattered = self._broadcast_and_scatter_along_last_dim( + y_kflat, self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, self._fft_idx + ) # reshape to ..., other, coil, k2_enc, k1_enc, k0_enc y_reshaped = y_scattered.reshape( @@ -142,3 +178,37 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: ) return (y_reshaped,) + + @staticmethod + def _broadcast_and_scatter_along_last_dim( + data_to_scatter: torch.Tensor, n_last_dim: int, scatter_index: torch.Tensor + ) -> torch.Tensor: + """Broadcast scatter index and scatter into zero tensor. + + Parameters + ---------- + data_to_scatter + Data to be scattered at indices scatter_index + n_last_dim + Number of data points in last dimension + scatter_index + Indices describing where to scatter data + + Returns + ------- + Data scattered into tensor along scatter_index + """ + # scatter does not broadcast, so we need to manually broadcast the indices + broadcast_shape = torch.broadcast_shapes(scatter_index.shape[:-1], data_to_scatter.shape[:-1]) + idx_expanded = torch.broadcast_to(scatter_index, (*broadcast_shape, scatter_index.shape[-1])) + + # although scatter_ is inplace, this will not cause issues with autograd, as self + # is always constant zero and gradients w.r.t. src work as expected. + data_scattered = torch.zeros( + *broadcast_shape, + n_last_dim, + dtype=data_to_scatter.dtype, + device=data_to_scatter.device, + ).scatter_(dim=-1, index=idx_expanded, src=data_to_scatter) + + return data_scattered diff --git a/tests/conftest.py b/tests/conftest.py index 3bd8946f2..30ae9c229 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -240,7 +240,7 @@ def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): k_list = [] for spacing, nk in zip([type_kz, type_ky, type_kx], [nkz, nky, nkx], strict=True): if spacing == 'non-uniform': - k = random_generator.float32_tensor(size=nk) + k = random_generator.float32_tensor(size=nk, low=-1, high=1) * max(nk) elif spacing == 'uniform': k = create_uniform_traj(nk, k_shape=k_shape) elif spacing == 'zero': diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 28c6e8860..c1738b7bb 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -118,3 +118,28 @@ def test_cart_sampling_op_fwd_adj(sampling): u = random_generator.complex64_tensor(size=k_shape) v = random_generator.complex64_tensor(size=k_shape[:2] + trajectory.as_tensor().shape[2:]) dotproduct_adjointness_test(sampling_op, u, v) + + +@pytest.mark.parametrize(('k2_min', 'k2_max'), [(-1, 21), (-21, 1)]) +@pytest.mark.parametrize(('k0_min', 'k0_max'), [(-6, 13), (-13, 6)]) +def test_cart_sampling_op_oversampling(k0_min, k0_max, k2_min, k2_max): + """Test trajectory points outside of encoding_matrix.""" + encoding_matrix = SpatialDimension(40, 1, 20) + + # Create kx and kz sampling which are asymmetric and larger than the encoding matrix on one side + # The indices are inverted to ensure CartesianSamplingOp acts on them + kx = rearrange(torch.linspace(k0_max, k0_min, 20), 'kx->1 1 1 kx') + ky = torch.ones(1, 1, 1, 1) + kz = rearrange(torch.linspace(k2_max, k2_min, 40), 'kz-> kz 1 1') + kz = torch.stack([kz, -kz], dim=0) # different kz values for two other elements + trajectory = KTrajectory(kz=kz, ky=ky, kx=kx) + + with pytest.warns(UserWarning, match='K-space points lie outside of the encoding_matrix'): + sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory) + + random_generator = RandomGenerator(seed=0) + u = random_generator.complex64_tensor(size=(3, 2, 5, kz.shape[-3], ky.shape[-2], kx.shape[-1])) + v = random_generator.complex64_tensor(size=(3, 2, 5, *encoding_matrix.zyx)) + + assert sampling_op.adjoint(u)[0].shape[-3:] == encoding_matrix.zyx + assert sampling_op(v)[0].shape[-3:] == (kz.shape[-3], ky.shape[-2], kx.shape[-1]) diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index f48a24260..2d76642c3 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -73,7 +73,11 @@ def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) - encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) with pytest.raises(NotImplementedError, match='Cartesian FFT dims need to be aligned'): FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) From 798f1e26d4736127fe697319b3354e0b7f703105 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Wed, 13 Nov 2024 10:02:25 +0100 Subject: [PATCH 15/35] Switch to adjoint as backward in FourierOP (#516) --- src/mrpro/operators/FourierOp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index c57d1fbfd..cacdda1dc 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -16,7 +16,7 @@ from mrpro.operators.LinearOperator import LinearOperator -class FourierOp(LinearOperator): +class FourierOp(LinearOperator, adjoint_as_backward=True): """Fourier Operator class.""" def __init__( From 37ae8b9e1bed3f04e3f9bd4f8f5228c294c10bde Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Wed, 13 Nov 2024 22:47:01 +0100 Subject: [PATCH 16/35] Make randperm reproducible in tests (#517) --- tests/_RandomGenerator.py | 72 ++++++++++++++----- tests/data/_IsmrmrdRawTestData.py | 4 +- tests/operators/test_cartesian_sampling_op.py | 8 +-- tests/operators/test_fourier_op.py | 2 +- tests/utils/test_split_idx.py | 4 +- 5 files changed, 66 insertions(+), 24 deletions(-) diff --git a/tests/_RandomGenerator.py b/tests/_RandomGenerator.py index 9248271e9..1d829b55a 100644 --- a/tests/_RandomGenerator.py +++ b/tests/_RandomGenerator.py @@ -27,11 +27,12 @@ class RandomGenerator: """ def __init__(self, seed): + """Initialize with a fixed seed.""" self.generator = torch.Generator().manual_seed(seed) @staticmethod def _clip_bounds(low, high, lowest, highest): - """Clips the bounds (low, high) to the given range (lowest, highest)""" + """Clips the bounds (low, high) to the given range (lowest, highest).""" if low > high: raise ValueError('low should be lower than high') low = max(low, lowest) @@ -54,24 +55,25 @@ def _dtype_bounds(dtype): return (info.min, info.max) def _randint(self, size, low, high, dtype=torch.int64) -> torch.Tensor: - """Generates uniform random integers in the range [low, high) with the - given dtype.""" + """Generate uniform random integers in [low, high) with given dtype.""" low, high = self._clip_bounds(low, high, *self._dtype_bounds(dtype)) return torch.randint(low, high, size, generator=self.generator, dtype=dtype) def _rand(self, size, low, high, dtype=torch.float32) -> torch.Tensor: - """Generates uniform random floats in the range [low, high) with the - given dtype.""" + """Generate uniform random floats in [low, high) with given dtype.""" low, high = self._clip_bounds(low, high, *self._dtype_bounds(dtype)) return (torch.rand(size, generator=self.generator, dtype=dtype) * (high - low)) + low def float32_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, high: float = 1.0) -> torch.Tensor: + """Generate float32 tensor of given size in [low, high).""" return self._rand(size, low, high, torch.float32) def float64_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, high: float = 1.0) -> torch.Tensor: + """Generate float64 tensor of given size in [low, high).""" return self._rand(size, low, high, torch.float64) def complex64_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, high: float = 1.0) -> torch.Tensor: + """Generate complex64 tensor of given size in [low, high).""" if low < 0: raise ValueError('low/high refer to the amplitude and must be positive') amp = self.float32_tensor(size, low, high) @@ -79,6 +81,7 @@ def complex64_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, h return (amp * torch.exp(1j * phase)).to(dtype=torch.complex64) def complex128_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, high: float = 1.0) -> torch.Tensor: + """Generate complex128 tensor of given size in [low, high).""" if low < 0: raise ValueError('low/high refer to the amplitude and must be positive') amp = self.float64_tensor(size, low, high) @@ -86,15 +89,19 @@ def complex128_tensor(self, size: Sequence[int] | int = (1,), low: float = 0.0, return (amp * torch.exp(1j * phase)).to(dtype=torch.complex128) def int8_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 7, high: int = 1 << 7) -> torch.Tensor: + """Generate int8 tensor of given size in [low, high).""" return self._randint(size, low, high, dtype=torch.int8) def int16_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 15, high: int = 1 << 15) -> torch.Tensor: + """Generate int16 tensor of given size in [low, high).""" return self._randint(size, low, high, dtype=torch.int16) def int32_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 31, high: int = 1 << 31) -> torch.Tensor: + """Generate int32 tensor of given size in [low, high).""" return self._randint(size, low, high, dtype=torch.int32) def int64_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 63, high: int = 1 << 63) -> torch.Tensor: + """Generate int64 tensor of given size in [low, high).""" return self._randint(size, low, high, dtype=torch.int64) # There is no uint32 in pytorch yet @@ -106,106 +113,133 @@ def int64_tensor(self, size: Sequence[int] | int = (1,), low: int = -1 << 63, hi # return self._randint(size, low, high, dtype=torch.uint64) # noqa: ERA001 def uint8_tensor(self, size: Sequence[int] | int = (1,), low: int = 0, high: int = 1 << 8) -> torch.Tensor: + """Generate uint8 tensor of given size in [low, high).""" return self._randint(size, low, high, dtype=torch.uint8) def bool(self) -> bool: + """Generate a random boolean value.""" return self.uint8(0, 1) == 1 def float32(self, low: float = 0.0, high: float = 1.0) -> float: + """Generate a float32 in [low, high).""" return self.float32_tensor((1,), low, high).item() def float64(self, low: float = 0.0, high: float = 1.0) -> float: + """Generate a float64 in [low, high).""" return self.float64_tensor((1,), low, high).item() def complex64(self, low: float = 0, high: float = 1.0) -> complex: + """Generate a complex64 in [low, high).""" return self.complex64_tensor((1,), low, high).item() def complex128(self, low: float = 0, high: float = 1.0) -> complex: + """Generate a complex128 in [low, high).""" return self.complex128_tensor((1,), low, high).item() def uint8(self, low: int = 0, high: int = 1 << 8) -> int: + """Generate a uint8 in [low, high).""" return int(self.uint8_tensor((1,), low, high).item()) def uint16(self, low: int = 0, high: int = 1 << 16) -> int: + """Generate a uint16 in [low, high).""" + if low < 0 or high > 1 << 16: + raise ValueError('Low must be positive and high must be <= 2^16') + # using int32 as it is the smallest that can hold 2^16 (no uint32 in pytorch) return int(self.int32_tensor((1,), low, high).item()) def uint32(self, low: int = 0, high: int = 1 << 32) -> int: - # using int64 to avoid overflow + """Generate a uint32 in [low, high).""" + if low < 0 or high > 1 << 32: + raise ValueError('Low must be positive and high must be <= 2^32') + # using int64 as it is the smallest that can hold 2^32 (no uint64 in pytorch) return int(self.int64_tensor((1,), low, high).item()) def int8(self, low: int = -1 << 7, high: int = 1 << 7 - 1) -> int: + """Generate an int8 in [low, high).""" return int(self.int8_tensor((1,), low, high).item()) def int16(self, low: int = -1 << 15, high: int = 1 << 15 - 1) -> int: + """Generate an int16 in [low, high).""" return int(self.int16_tensor((1,), low, high).item()) def int32(self, low: int = -1 << 31, high: int = 1 << 31 - 1) -> int: + """Generate an int32 in [low, high).""" return int(self.int32_tensor((1,), low, high).item()) def int64(self, low: int = -1 << 63, high: int = 1 << 63 - 1) -> int: + """Generate an int64 in [low, high).""" return int(self.int64_tensor((1,), low, high).item()) def uint64(self, low: int = 0, high: int = 1 << 64) -> int: - # pytorch does not support uint64, so we use int64 instead - # and then convert to uint64 + """Generate a uint64 in [low, high).""" + if low < 0 or high > 1 << 64: + raise ValueError('Low must be positive and high must be <= 2^64') + # no uint64 in pytorch. int64 would not be able to produce 2^64, + # so we need to shift the values from [-2^63, 2^63) to [0, 2^64) range_ = high - low - if low < 0: - raise ValueError('Low must be positive') - if range_ > 1 << 64: - raise ValueError('Range too large') new_low = -1 << 63 new_high = new_low + range_ value = self.int64(new_low, new_high) - new_low + low return value def float32_tuple(self, size: int, low: float = 0, high: float = 1) -> tuple[float, ...]: + """Generate a tuple of float32 of given size in [low, high).""" return tuple(self.float32_tensor((size,), low, high)) def float64_tuple(self, size: int, low: float = 0, high: float = 1) -> tuple[float, ...]: + """Generate a tuple of float64 of given size in [low, high).""" return tuple(self.float64_tensor((size,), low, high)) def complex64_tuple(self, size: int, low: float = 0, high: float = 1) -> tuple[complex, ...]: + """Generate a tuple of complex64 of given size in [low, high).""" return tuple(self.complex64_tensor((size,), low, high)) def complex128_tuple(self, size: int, low: float = 0, high: float = 1) -> tuple[complex, ...]: + """Generate a tuple of complex128 of given size in [low, high).""" return tuple(self.complex128_tensor((size,), low, high)) def uint8_tuple(self, size: int, low: int = 0, high: int = 1 << 8) -> tuple[int, ...]: + """Generate a tuple of uint8 of given size in [low, high).""" return tuple(self.uint8_tensor((size,), low, high)) def uint16_tuple(self, size: int, low: int = 0, high: int = 1 << 16) -> tuple[int, ...]: - # no uint16_tensor, so we use uint16 instead + """Generate a tuple of uint16 of given size in [low, high).""" return tuple([self.uint16(low, high) for _ in range(size)]) def uint32_tuple(self, size: int, low: int = 0, high: int = 1 << 32) -> tuple[int, ...]: - # no uint32_tensor, so we use uint32 instead + """Generate a tuple of uint32 of given size in [low, high).""" return tuple([self.uint32(low, high) for _ in range(size)]) def uint64_tuple(self, size: int, low: int = 0, high: int = 1 << 64) -> tuple[int, ...]: - # no uint64_tensor, so we use uint64 instead + """Generate a tuple of uint64 of given size in [low, high).""" return tuple([self.uint64(low, high) for _ in range(size)]) def int8_tuple(self, size: int, low: int = -1 << 7, high: int = 1 << 7) -> tuple[int, ...]: + """Generate a tuple of int8 of given size in [low, high).""" return tuple(self.int8_tensor((size,), low, high)) def int16_tuple(self, size: int, low: int = -1 << 15, high: int = 1 << 15) -> tuple[int, ...]: + """Generate a tuple of int16 of given size in [low, high).""" return tuple(self.int16_tensor((size,), low, high)) def int32_tuple(self, size: int, low: int = -1 << 31, high: int = 1 << 31) -> tuple[int, ...]: + """Generate a tuple of int32 of given size in [low, high).""" return tuple(self.int32_tensor((size,), low, high)) def int64_tuple(self, size: int, low: int = -1 << 63, high: int = 1 << 63) -> tuple[int, ...]: + """Generate a tuple of int64 of given size in [low, high).""" return tuple(self.int64_tensor((size,), low, high)) def ascii(self, size: int) -> str: + """Generate a random ASCII string of given size.""" return ''.join([chr(self.uint8(32, 127)) for _ in range(size)]) def rand_like(self, x: torch.Tensor, low=0.0, high=1.0) -> torch.Tensor: - """Generate a tensor with the same shape as x filled with uniform random numbers in [low , high).""" + """Generate tensor like x with uniform random numbers in [low, high).""" return self.rand_tensor(x.shape, x.dtype, low=low, high=high) def rand_tensor(self, shape: Sequence[int], dtype: torch.dtype, low: float, high: float) -> torch.Tensor: - """Generates a tensor with the given shape and dtype filled with uniform random numbers in [low , high).""" + """Generate tensor of given shape and dtype in [low, high).""" if dtype.is_complex: tensor = self.complex64_tensor(shape, low, high).to(dtype=dtype) elif dtype.is_floating_point: @@ -215,3 +249,7 @@ def rand_tensor(self, shape: Sequence[int], dtype: torch.dtype, low: float, high else: tensor = self._randint(shape, low, high, dtype) return tensor + + def randperm(self, n, *, dtype=torch.int64) -> torch.Tensor: + """Generate random permutation of integers from 0 to n-1.""" + return torch.randperm(n, generator=self.generator, dtype=dtype) diff --git a/tests/data/_IsmrmrdRawTestData.py b/tests/data/_IsmrmrdRawTestData.py index 181be56d6..efeff6ed1 100644 --- a/tests/data/_IsmrmrdRawTestData.py +++ b/tests/data/_IsmrmrdRawTestData.py @@ -10,6 +10,8 @@ from mrpro.data import SpatialDimension from mrpro.phantoms import EllipsePhantom +from tests import RandomGenerator + ISMRMRD_TRAJECTORY_TYPE = ( 'cartesian', 'epi', @@ -347,7 +349,7 @@ def _cartesian_trajectory( # Different temporal orders of phase encoding points if sampling_order == 'random': - perm = torch.randperm(len(kpe)) + perm = RandomGenerator(13).randperm(len(kpe)) kpe = kpe[perm[: len(perm) // acceleration]] elif sampling_order == 'linear': kpe, _ = torch.sort(kpe) diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index c1738b7bb..5bff62dab 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -81,7 +81,7 @@ def test_cart_sampling_op_fwd_adj(sampling): # Subsample data and trajectory match sampling: case 'random': - random_idx = torch.randperm(k_shape[-2]) + random_idx = RandomGenerator(13).randperm(k_shape[-2]) trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx, :]) case 'partial_echo': trajectory = KTrajectory.from_tensor(trajectory_tensor[..., : k_shape[-1] // 2]) @@ -90,11 +90,11 @@ def test_cart_sampling_op_fwd_adj(sampling): case 'regular_undersampling': trajectory = KTrajectory.from_tensor(trajectory_tensor[..., ::3, ::5, :]) case 'random_undersampling': - random_idx = torch.randperm(k_shape[-2]) + random_idx = RandomGenerator(13).randperm(k_shape[-2]) trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: k_shape[-2] // 2], :]) case 'different_random_undersampling': traj_list = [ - traj_one_other[..., torch.randperm(k_shape[-2])[: k_shape[-2] // 2], :] + traj_one_other[..., RandomGenerator(13).randperm(k_shape[-2])[: k_shape[-2] // 2], :] for traj_one_other in trajectory_tensor.unbind(1) ] trajectory = KTrajectory.from_tensor(torch.stack(traj_list, dim=1)) @@ -105,7 +105,7 @@ def test_cart_sampling_op_fwd_adj(sampling): trajectory = KTrajectory.from_tensor(trajectory_tensor) case 'kx_ky_along_k0_undersampling': trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)') - random_idx = torch.randperm(trajectory_tensor.shape[-1]) + random_idx = RandomGenerator(13).randperm(trajectory_tensor.shape[-1]) trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: trajectory_tensor.shape[-1] // 2]]) case _: raise NotImplementedError(f'Test {sampling} not implemented.') diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index 2d76642c3..5eccbbc1b 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -89,7 +89,7 @@ def test_fourier_op_cartesian_sorting(ismrmrd_cart): (img,) = ff_op.adjoint(kdata.data) # shuffle the kspace points along k0 - permutation_index = torch.randperm(kdata.data.shape[-1]) + permutation_index = RandomGenerator(13).randperm(kdata.data.shape[-1]) kdata_unsorted = KData( header=kdata.header, data=kdata.data[..., permutation_index], diff --git a/tests/utils/test_split_idx.py b/tests/utils/test_split_idx.py index 6997fc1bc..30501b9ac 100644 --- a/tests/utils/test_split_idx.py +++ b/tests/utils/test_split_idx.py @@ -5,6 +5,8 @@ from einops import repeat from mrpro.utils import split_idx +from tests import RandomGenerator + @pytest.mark.parametrize( ('ni_per_block', 'ni_overlap', 'cyclic', 'unique_values_in_last_block'), @@ -19,7 +21,7 @@ def test_split_idx(ni_per_block, ni_overlap, cyclic, unique_values_in_last_block # Create a regular sequence of values vals = repeat(torch.tensor([0, 1, 2, 3]), 'd0 -> (d0 repeat)', repeat=5) # Mix up values - vals = vals[torch.randperm(vals.shape[0])] + vals = vals[RandomGenerator(13).randperm(vals.shape[0])] # Split indices of sorted sequence idx_split = split_idx(torch.argsort(vals), ni_per_block, ni_overlap, cyclic) From f4b7f4ef95007bd9bc96b9c65ab9826cb363f25c Mon Sep 17 00:00:00 2001 From: Stefan Martin <67423672+Stef-Martin@users.noreply.github.com> Date: Thu, 14 Nov 2024 01:26:14 +0100 Subject: [PATCH 17/35] Fix operator norm device (#520) Fixes #512 --- src/mrpro/operators/LinearOperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/operators/LinearOperator.py b/src/mrpro/operators/LinearOperator.py index d919e63c6..029089f5d 100644 --- a/src/mrpro/operators/LinearOperator.py +++ b/src/mrpro/operators/LinearOperator.py @@ -163,7 +163,7 @@ def operator_norm( # operator norm is a strictly positive number. This ensures that the first time the # change between the old and the new estimate of the operator norm is non-zero and # thus prevents the loop from exiting despite a non-correct estimate. - op_norm_old = torch.zeros(*tuple([1 for _ in range(vector.ndim)])) + op_norm_old = torch.zeros(*tuple([1 for _ in range(vector.ndim)]), device=vector.device) dim = tuple(dim) if dim is not None else dim for _ in range(max_iterations): From 85ab6e7500080fc0ec55e3d6313631686d48cbbf Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Thu, 14 Nov 2024 11:06:08 +0100 Subject: [PATCH 18/35] Add coil compression using PCA (#470) Co-authored-by: Felix Zimmermann --- src/mrpro/data/_kdata/KData.py | 92 +++++++++++++++++++++++++++++++++- tests/data/test_kdata.py | 56 +++++++++++++++++++++ 2 files changed, 147 insertions(+), 1 deletion(-) diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/_kdata/KData.py index aaf430497..4b5df6250 100644 --- a/src/mrpro/data/_kdata/KData.py +++ b/src/mrpro/data/_kdata/KData.py @@ -3,8 +3,9 @@ import dataclasses import datetime import warnings -from collections.abc import Callable +from collections.abc import Callable, Sequence from pathlib import Path +from types import EllipsisType import h5py import ismrmrd @@ -276,3 +277,92 @@ def __repr__(self): f'{self.header}' ) return out + + def compress_coils( + self: Self, + n_compressed_coils: int, + batch_dims: None | Sequence[int] = None, + joint_dims: Sequence[int] | EllipsisType = ..., + ) -> Self: + """Reduce the number of coils based on a PCA compression. + + A PCA is carried out along the coil dimension and the n_compressed_coils virtual coil elements are selected. For + more information on coil compression please see [BUE2007]_, [DON2008]_ and [HUA2008]_. + + Returns a copy of the data. + + Parameters + ---------- + kdata + K-space data + n_compressed_coils + Number of compressed coils + batch_dims + Dimensions which are treated as batched, i.e. separate coil compression matrizes (e.g. different slices). + Default is to do one coil compression matrix for the entire k-space data. Only batch_dim or joint_dim can + be defined. If batch_dims is not None then joint_dims has to be ... + joint_dims + Dimensions which are combined to calculate single coil compression matrix (e.g. k0, k1, contrast). Default + is that all dimensions (except for the coil dimension) are joint_dims. Only batch_dim or joint_dim can + be defined. If joint_dims is not ... batch_dims has to be None + + Returns + ------- + Copy of K-space data with compressed coils. + + Raises + ------ + ValueError + If both batch_dims and joint_dims are defined. + Valuer Error + If coil dimension is part of joint_dims or batch_dims. + + References + ---------- + .. [BUE2007] Buehrer M, Pruessmann KP, Boesiger P, Kozerke S (2007) Array compression for MRI with large coil + arrays. MRM 57. https://doi.org/10.1002/mrm.21237 + .. [DON2008] Doneva M, Boernert P (2008) Automatic coil selection for channel reduction in SENSE-based parallel + imaging. MAGMA 21. https://doi.org/10.1007/s10334-008-0110-x + .. [HUA2008] Huang F, Vijayakumar S, Li Y, Hertel S, Duensing GR (2008) A software channel compression + technique for faster reconstruction with many channels. MRM 26. https://doi.org/10.1016/j.mri.2007.04.010 + + """ + from mrpro.operators import PCACompressionOp + + coil_dim = -4 % self.data.ndim + if batch_dims is not None and joint_dims is not Ellipsis: + raise ValueError('Either batch_dims or joint_dims can be defined not both.') + + if joint_dims is not Ellipsis: + joint_dims_normalized = [i % self.data.ndim for i in joint_dims] + if coil_dim in joint_dims_normalized: + raise ValueError('Coil dimension must not be in joint_dims') + batch_dims_normalized = [ + d for d in range(self.data.ndim) if d not in joint_dims_normalized and d is not coil_dim + ] + else: + batch_dims_normalized = [] if batch_dims is None else [i % self.data.ndim for i in batch_dims] + if coil_dim in batch_dims_normalized: + raise ValueError('Coil dimension must not be in batch_dims') + + # reshape to (*batch dimension, -1, coils) + permute_order = ( + batch_dims_normalized + + [i for i in range(self.data.ndim) if i != coil_dim and i not in batch_dims_normalized] + + [coil_dim] + ) + kdata_coil_compressed = self.data.permute(permute_order) + permuted_kdata_shape = kdata_coil_compressed.shape + kdata_coil_compressed = kdata_coil_compressed.flatten( + start_dim=len(batch_dims_normalized), end_dim=-2 + ) # keep separate dimensions and coil + + pca_compression_op = PCACompressionOp(data=kdata_coil_compressed, n_components=n_compressed_coils) + (kdata_coil_compressed,) = pca_compression_op(kdata_coil_compressed) + + # reshape to original dimensions and undo permutation + kdata_coil_compressed = torch.reshape( + kdata_coil_compressed, [*permuted_kdata_shape[:-1], n_compressed_coils] + ).permute(*np.argsort(permute_order)) + + return type(self)(self.header.clone(), kdata_coil_compressed, self.traj.clone()) diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index d5cfa0f0c..afb6cb977 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -518,3 +518,59 @@ def test_modify_acq_info(random_kheader_shape): assert kheader.acq_info.idx.k1.shape == (n_other, n_k2, n_k1) assert kheader.acq_info.orientation.shape == (n_other, n_k2, n_k1, 1) assert kheader.acq_info.position.z.shape == (n_other, n_k2, n_k1, 1) + + +def test_KData_compress_coils(ismrmrd_cart): + """Test coil combination does not alter image content (much).""" + kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) + kdata = kdata.compress_coils(n_compressed_coils=4) + ff_op = FastFourierOp(dim=(-1, -2)) + (reconstructed_img,) = ff_op.adjoint(kdata.data) + + # Image content of each coil is the same. Therefore we only compare one coil image but we need to normalize. + reconstructed_img = reconstructed_img[0, 0, 0, ...].abs() + reconstructed_img /= reconstructed_img.max() + ref_img = ismrmrd_cart.img_ref[0, 0, 0, ...].abs() + ref_img /= ref_img.max() + + assert relative_image_difference(reconstructed_img, ref_img) <= 0.1 + + +@pytest.mark.parametrize( + ('batch_dims', 'joint_dims'), + [ + (None, ...), + ((0,), ...), + ((-2, -1), ...), + (None, (-1, -2, -3)), + (None, (0, -1, -2, -3)), + ], + ids=[ + 'single_compression', + 'batching_along_dim0', + 'batching_along_dim-2_and_dim-1', + 'single_compression_for_last_3_dims', + 'single_compression_for_last_3_and_first_dims', + ], +) +def test_KData_compress_coils_diff_batch_joint_dims(consistently_shaped_kdata, batch_dims, joint_dims): + """Test that all of these options work and yield the same shape.""" + n_compressed_coils = 4 + orig_kdata_shape = consistently_shaped_kdata.data.shape + kdata = consistently_shaped_kdata.compress_coils(n_compressed_coils, batch_dims, joint_dims) + assert kdata.data.shape == (*orig_kdata_shape[:-4], n_compressed_coils, *orig_kdata_shape[-3:]) + + +def test_KData_compress_coils_error_both_batch_and_joint(consistently_shaped_kdata): + """Test if error is raised if both batch_dims and joint_dims is defined.""" + with pytest.raises(ValueError, match='Either batch_dims or joint_dims'): + consistently_shaped_kdata.compress_coils(n_compressed_coils=3, batch_dims=(0,), joint_dims=(0,)) + + +def test_KData_compress_coils_error_coil_dim(consistently_shaped_kdata): + """Test if error is raised if coil_dim is in batch_dims or joint_dims.""" + with pytest.raises(ValueError, match='Coil dimension must not'): + consistently_shaped_kdata.compress_coils(n_compressed_coils=3, batch_dims=(-4,)) + + with pytest.raises(ValueError, match='Coil dimension must not'): + consistently_shaped_kdata.compress_coils(n_compressed_coils=3, joint_dims=(-4,)) From d652c9da73e786907b119a8e66a270124895754b Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Thu, 14 Nov 2024 11:41:11 +0100 Subject: [PATCH 19/35] Fix typing issues (#515) --- src/mrpro/algorithms/optimizers/cg.py | 4 ++-- src/mrpro/phantoms/coils.py | 9 ++++----- tests/operators/models/conftest.py | 4 +++- .../test_transient_steady_state_with_preparation.py | 6 +++--- tests/operators/models/test_wasabi.py | 8 +++++--- tests/operators/models/test_wasabiti.py | 8 +++++--- 6 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/mrpro/algorithms/optimizers/cg.py b/src/mrpro/algorithms/optimizers/cg.py index b879796b0..2d458bfa0 100644 --- a/src/mrpro/algorithms/optimizers/cg.py +++ b/src/mrpro/algorithms/optimizers/cg.py @@ -82,7 +82,7 @@ def cg( return solution # dummy value. new value will be set in loop before first usage - residual_norm_squared_previous = None + residual_norm_squared_previous: torch.Tensor | None = None for iteration in range(max_iterations): # calculate the square norm of the residual @@ -93,7 +93,7 @@ def cg( if tolerance != 0 and (residual_norm_squared < tolerance**2): return solution - if iteration > 0: + if residual_norm_squared_previous is not None: # not first iteration beta = residual_norm_squared / residual_norm_squared_previous conjugate_vector = residual + beta * conjugate_vector diff --git a/src/mrpro/phantoms/coils.py b/src/mrpro/phantoms/coils.py index e936900de..dd9c208fe 100644 --- a/src/mrpro/phantoms/coils.py +++ b/src/mrpro/phantoms/coils.py @@ -1,6 +1,5 @@ """Numerical coil simulations.""" -import numpy as np import torch from einops import repeat @@ -45,16 +44,16 @@ def birdcage_2d( y_co = repeat(y_co, 'y x -> coils y x', coils=1) c = repeat(torch.linspace(0, dim[0] - 1, dim[0]), 'coils -> coils y x', y=1, x=1) - coil_center_x = dim[2] * relative_radius * np.cos(c * (2 * torch.pi / dim[0])) - coil_center_y = dim[1] * relative_radius * np.sin(c * (2 * torch.pi / dim[0])) + coil_center_x = dim[2] * relative_radius * torch.cos(c * (2 * torch.pi / dim[0])) + coil_center_y = dim[1] * relative_radius * torch.sin(c * (2 * torch.pi / dim[0])) coil_phase = -c * (2 * torch.pi / dim[0]) rr = torch.sqrt((x_co - coil_center_x) ** 2 + (y_co - coil_center_y) ** 2) phi = torch.arctan2((x_co - coil_center_x), -(y_co - coil_center_y)) + coil_phase - sensitivities = (1 / rr) * np.exp(1j * phi) + sensitivities = (1 / rr) * torch.exp(1j * phi) if normalize_with_rss: - rss = torch.sqrt(torch.sum(torch.abs(sensitivities) ** 2, 0)) + rss = sensitivities.abs().square().sum(0).sqrt() # Normalize only where rss is > 0 sensitivities[:, rss > 0] /= rss[None, rss > 0] diff --git a/tests/operators/models/conftest.py b/tests/operators/models/conftest.py index 75fceacd2..4aab81ae0 100644 --- a/tests/operators/models/conftest.py +++ b/tests/operators/models/conftest.py @@ -46,7 +46,9 @@ ) -def create_parameter_tensor_tuples(parameter_shape=(10, 5, 100, 100, 100), number_of_tensors=2): +def create_parameter_tensor_tuples( + parameter_shape=(10, 5, 100, 100, 100), number_of_tensors=2 +) -> tuple[torch.Tensor, ...]: """Create tuples of tensors as input to operators.""" random_generator = RandomGenerator(seed=0) parameter_tensors = random_generator.float32_tensor(size=(number_of_tensors, *parameter_shape), low=1e-10) diff --git a/tests/operators/models/test_transient_steady_state_with_preparation.py b/tests/operators/models/test_transient_steady_state_with_preparation.py index 2b9d2e451..553a8dbcc 100644 --- a/tests/operators/models/test_transient_steady_state_with_preparation.py +++ b/tests/operators/models/test_transient_steady_state_with_preparation.py @@ -66,9 +66,9 @@ def test_transient_steady_state_shape(parameter_shape, contrast_dim_shape, signa """Test correct signal shapes.""" (sampling_time,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) if len(parameter_shape) == 1: - repetition_time = 5 - m0_scaling_preparation = 1 - delay_after_preparation = 0.01 + repetition_time: float | torch.Tensor = 5 + m0_scaling_preparation: float | torch.Tensor = 1 + delay_after_preparation: float | torch.Tensor = 0.01 else: repetition_time, m0_scaling_preparation, delay_after_preparation = create_parameter_tensor_tuples( contrast_dim_shape[1:], number_of_tensors=3 diff --git a/tests/operators/models/test_wasabi.py b/tests/operators/models/test_wasabi.py index 6b4494539..3d5c9e312 100644 --- a/tests/operators/models/test_wasabi.py +++ b/tests/operators/models/test_wasabi.py @@ -5,7 +5,9 @@ from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples -def create_data(offset_max=500, n_offsets=101, b0_shift=0, rb1=1.0, c=1.0, d=2.0): +def create_data( + offset_max=500, n_offsets=101, b0_shift=0, rb1=1.0, c=1.0, d=2.0 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: offsets = torch.linspace(-offset_max, offset_max, n_offsets) return offsets, torch.Tensor([b0_shift]), torch.Tensor([rb1]), torch.Tensor([c]), torch.Tensor([d]) @@ -20,8 +22,8 @@ def test_WASABI_shift(): wasabi_model = WASABI(offsets=offsets_shifted) (signal_shifted,) = wasabi_model(b0_shift, rb1, c, d) - lower_index = (offsets_shifted == -300).nonzero()[0][0].item() - upper_index = (offsets_shifted == 500).nonzero()[0][0].item() + lower_index = int((offsets_shifted == -300).nonzero()[0][0]) + upper_index = int((offsets_shifted == 500).nonzero()[0][0]) assert signal[0] == signal[-1], 'Result should be symmetric around center' assert signal_shifted[lower_index] == signal_shifted[upper_index], 'Result should be symmetric around shift' diff --git a/tests/operators/models/test_wasabiti.py b/tests/operators/models/test_wasabiti.py index de2cacc50..eeedf9463 100644 --- a/tests/operators/models/test_wasabiti.py +++ b/tests/operators/models/test_wasabiti.py @@ -6,7 +6,9 @@ from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples -def create_data(offset_max=500, n_offsets=101, b0_shift=0, rb1=1.0, t1=1.0): +def create_data( + offset_max=500, n_offsets=101, b0_shift=0, rb1=1.0, t1=1.0 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: offsets = torch.linspace(-offset_max, offset_max, n_offsets) return offsets, torch.Tensor([b0_shift]), torch.Tensor([rb1]), torch.Tensor([t1]) @@ -28,8 +30,8 @@ def test_WASABITI_symmetry_after_shift(): wasabiti_model = WASABITI(offsets=offsets_shifted, trec=trec) (signal_shifted,) = wasabiti_model(b0_shift, rb1, t1) - lower_index = (offsets_shifted == -300).nonzero()[0][0].item() - upper_index = (offsets_shifted == 500).nonzero()[0][0].item() + lower_index = int((offsets_shifted == -300).nonzero()[0][0]) + upper_index = int((offsets_shifted == 500).nonzero()[0][0]) assert signal_shifted[lower_index] == signal_shifted[upper_index], 'Result should be symmetric around shift' From e42f6665c63f62a643a5b964b6458dabc7b22507 Mon Sep 17 00:00:00 2001 From: Lunin Leonid Date: Thu, 14 Nov 2024 12:47:41 +0100 Subject: [PATCH 20/35] Switch to myst-nb for.ipynb in docs Co-authored-by: Felix F Zimmermann --- .github/workflows/docs.yml | 40 ++++++-------------------------------- .gitignore | 1 + docs/source/conf.py | 12 ++++++++++++ docs/source/examples.rst | 4 ++++ pyproject.toml | 7 ++++++- 5 files changed, 29 insertions(+), 35 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index bb861a8d2..cc5ff5458 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -120,14 +120,14 @@ jobs: run: | notebook=${{ matrix.notebook }} echo "ARTIFACT_NAME=${notebook/.ipynb/}" >> $GITHUB_OUTPUT - echo "HTML_RESULT=${notebook/.ipynb/.html}" >> $GITHUB_OUTPUT + echo "IPYNB_EXECUTED=${notebook}" >> $GITHUB_OUTPUT - name: Upload notebook uses: actions/upload-artifact@v4 if: always() with: name: ${{ steps.artifact_names.outputs.ARTIFACT_NAME }} - path: ${{ github.workspace }}/nb-runner.out/${{ steps.artifact_names.outputs.HTML_RESULT }} + path: ${{ github.workspace }}/nb-runner.out/${{ steps.artifact_names.outputs.IPYNB_EXECUTED }} env: RUNNER: ${{ toJson(runner) }} @@ -150,39 +150,11 @@ jobs: - name: Install mrpro and dependencies run: pip install --upgrade --upgrade-strategy "eager" .[docs] - - name: Download notebook html files + - name: Download executed notebook ipynb files id: download uses: actions/download-artifact@v4 with: - path: ./docs/source/notebook_artifact/ - - - name: Copy notebook html files - run: | - mkdir ./docs/source/_notebooks - cd ./docs/source/notebook_artifact/ - notebooks=$(grep -rl --include='*' './') - for nb in $notebooks - do - echo "current jupyter-notebook: $nb" - cp ./$nb ../_notebooks/ - done - - - name: List of notebooks - run: | - cd ./docs/source/_notebooks/ - notebooks=$(grep -rl --include='*.html' './') - cd ../ - echo "" >> examples.rst - for nb in $notebooks - do - echo " notebook_${nb/.html/.rst}" >> examples.rst - notebook_description=$(grep '

\(.*\) "notebook_${nb/.html/.rst}" - echo "========" >> "notebook_${nb/.html/.rst}" - echo ".. raw:: html" >> "notebook_${nb/.html/.rst}" - echo " :file: ./_notebooks/$nb" >> "notebook_${nb/.html/.rst}" - done + path: ./docs/source/_notebooks/ - name: Build docs run: | @@ -195,7 +167,7 @@ jobs: with: name: Documentation path: docs/build/html/ - + - run: echo 'Artifact url ${{ steps.save_docu.outputs.artifact-url }}' - run: echo 'Event number ${{ github.event.number }}' @@ -225,7 +197,7 @@ jobs: deploy: if: github.ref == 'refs/heads/main' permissions: - pages: write + pages: write id-token: write environment: name: github-pages diff --git a/.gitignore b/.gitignore index c1694cad9..29cb7951b 100644 --- a/.gitignore +++ b/.gitignore @@ -91,6 +91,7 @@ instance/ # Sphinx documentation docs/_build/ +docs/source/_notebooks/* # PyBuilder .pybuilder/ diff --git a/docs/source/conf.py b/docs/source/conf.py index 51102fd04..02f75262b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -34,6 +34,9 @@ 'sphinx.ext.autosummary', 'sphinx.ext.viewcode', 'sphinx.ext.napoleon', + 'myst_nb', + 'sphinx.ext.mathjax', + 'sphinx-mathjax-offline' ] autosummary_generate = True autosummary_imported_members = False @@ -43,6 +46,11 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] source_suffix = {'.rst': 'restructuredtext', '.txt': 'restructuredtext', '.md': 'markdown'} +myst_enable_extensions = [ + "amsmath", + "dollarmath", +] +nb_execution_mode = "off" # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output @@ -69,3 +77,7 @@ }, ], } + +def setup(app): + # forces mathjax on all pages + app.set_html_assets_policy('always') diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 993517fd6..fe24a60a8 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -6,3 +6,7 @@ All of the notebooks can directly be run via binder or colab from the repo. .. toctree:: :maxdepth: 1 + :caption: Contents: + :glob: + + _notebooks/*/* diff --git a/pyproject.toml b/pyproject.toml index 31798a35c..3e1b3a78d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,12 @@ test = [ "pytest-cov", "pytest-xdist", ] -docs = ["sphinx", "sphinx_rtd_theme", "sphinx-pyproject"] +docs = ["sphinx", + "sphinx_rtd_theme", + "sphinx-pyproject", + "myst-nb", + "sphinx-mathjax-offline", + ] notebook = [ "zenodo_get", "ipykernel", From 41ad113d11e675785fad14ac5f6b2a2cc14b7403 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Thu, 14 Nov 2024 12:59:36 +0100 Subject: [PATCH 21/35] Add tests for autodiff of operators (#405) Co-authored-by: Felix F Zimmermann --- tests/__init__.py | 1 + tests/algorithms/csm/test_inati.py | 2 +- tests/algorithms/csm/test_walsh.py | 2 +- tests/data/test_csm_data.py | 2 +- tests/data/test_kdata.py | 2 +- tests/helper.py | 61 ++++++++++++++++--- .../models/test_inversion_recovery.py | 8 +++ tests/operators/models/test_molli.py | 9 ++- .../models/test_mono_exponential_decay.py | 8 +++ .../models/test_saturation_recovery.py | 8 +++ ...transient_steady_state_with_preparation.py | 15 +++++ tests/operators/models/test_wasabi.py | 8 +++ tests/operators/models/test_wasabiti.py | 9 +++ tests/operators/test_autograd_linop.py | 3 +- tests/operators/test_cartesian_sampling_op.py | 3 +- tests/operators/test_constraints_op.py | 14 ++++- .../operators/test_density_compensation_op.py | 3 +- tests/operators/test_einsum_op.py | 3 +- tests/operators/test_fast_fourier_op.py | 3 +- tests/operators/test_finite_difference_op.py | 3 +- tests/operators/test_fourier_op.py | 3 +- tests/operators/test_grid_sampling_op.py | 3 +- tests/operators/test_linearoperatormatrix.py | 3 +- tests/operators/test_magnitude_op.py | 16 +++-- tests/operators/test_operators.py | 3 +- tests/operators/test_pca_compression_op.py | 3 +- tests/operators/test_phase_op.py | 16 +++-- tests/operators/test_rearrangeop.py | 3 +- tests/operators/test_sensitivity_op.py | 3 +- tests/operators/test_slice_projection_op.py | 3 +- tests/operators/test_wavelet_op.py | 5 +- tests/operators/test_zero_op.py | 3 +- tests/operators/test_zero_pad_op.py | 3 +- tests/phantoms/test_ellipse_phantom.py | 2 +- 34 files changed, 176 insertions(+), 60 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index e05dc1d29..675fd5e26 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1,2 @@ from ._RandomGenerator import RandomGenerator +from .helper import relative_image_difference, dotproduct_adjointness_test, operator_isometry_test, linear_operator_unitary_test, autodiff_test diff --git a/tests/algorithms/csm/test_inati.py b/tests/algorithms/csm/test_inati.py index 0e179b2e5..beaa2fa5d 100644 --- a/tests/algorithms/csm/test_inati.py +++ b/tests/algorithms/csm/test_inati.py @@ -3,8 +3,8 @@ import torch from mrpro.algorithms.csm import inati from mrpro.data import SpatialDimension +from tests import relative_image_difference from tests.algorithms.csm.conftest import multi_coil_image -from tests.helper import relative_image_difference def test_inati(ellipse_phantom, random_kheader): diff --git a/tests/algorithms/csm/test_walsh.py b/tests/algorithms/csm/test_walsh.py index 36eee7b60..e08dd6ce3 100644 --- a/tests/algorithms/csm/test_walsh.py +++ b/tests/algorithms/csm/test_walsh.py @@ -3,8 +3,8 @@ import torch from mrpro.algorithms.csm import walsh from mrpro.data import SpatialDimension +from tests import relative_image_difference from tests.algorithms.csm.conftest import multi_coil_image -from tests.helper import relative_image_difference def test_walsh(ellipse_phantom, random_kheader): diff --git a/tests/data/test_csm_data.py b/tests/data/test_csm_data.py index 0246a2a07..bb759a35a 100644 --- a/tests/data/test_csm_data.py +++ b/tests/data/test_csm_data.py @@ -6,8 +6,8 @@ import torch from mrpro.data import CsmData, SpatialDimension +from tests import relative_image_difference from tests.algorithms.csm.test_walsh import multi_coil_image -from tests.helper import relative_image_difference def test_CsmData_is_frozen_dataclass(random_test_data, random_kheader): diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index afb6cb977..fa3e4ebd9 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -10,9 +10,9 @@ from mrpro.operators import FastFourierOp from mrpro.utils import split_idx +from tests import relative_image_difference from tests.conftest import RandomGenerator, generate_random_data from tests.data import IsmrmrdRawTestData -from tests.helper import relative_image_difference from tests.phantoms import EllipsePhantomTestData diff --git a/tests/helper.py b/tests/helper.py index 794f5685a..7e11826a7 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -1,7 +1,8 @@ """Helper/Utilities for test functions.""" import torch -from mrpro.operators import Operator +from mrpro.operators import LinearOperator, Operator +from typing_extensions import TypeVarTuple, Unpack def relative_image_difference(img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor: @@ -26,9 +27,13 @@ def relative_image_difference(img1: torch.Tensor, img2: torch.Tensor) -> torch.T def dotproduct_adjointness_test( - operator: Operator, u: torch.Tensor, v: torch.Tensor, relative_tolerance: float = 1e-3, absolute_tolerance=1e-5 + operator: LinearOperator, + u: torch.Tensor, + v: torch.Tensor, + relative_tolerance: float = 1e-3, + absolute_tolerance=1e-5, ): - """Test the adjointness of operator and operator.H. + """Test the adjointness of linear operator and operator.H. Test if == @@ -42,7 +47,7 @@ def dotproduct_adjointness_test( Parameters ---------- operator - operator + linear operator u element of the domain of the operator v @@ -74,9 +79,12 @@ def dotproduct_adjointness_test( def operator_isometry_test( - operator: Operator, u: torch.Tensor, relative_tolerance: float = 1e-3, absolute_tolerance=1e-5 + operator: Operator[torch.Tensor, tuple[torch.Tensor]], + u: torch.Tensor, + relative_tolerance: float = 1e-3, + absolute_tolerance=1e-5, ): - """Test the isometry of an operator. + """Test the isometry of a operator. Test if ||Operator(u)|| == ||u|| @@ -103,10 +111,10 @@ def operator_isometry_test( ) -def operator_unitary_test( - operator: Operator, u: torch.Tensor, relative_tolerance: float = 1e-3, absolute_tolerance=1e-5 +def linear_operator_unitary_test( + operator: LinearOperator, u: torch.Tensor, relative_tolerance: float = 1e-3, absolute_tolerance=1e-5 ): - """Test if an operator is unitary. + """Test if a linear operator is unitary. Test if Operator.adjoint(Operator(u)) == u @@ -115,7 +123,7 @@ def operator_unitary_test( Parameters ---------- operator - operator + linear operator u element of the domain of the operator relative_tolerance @@ -129,3 +137,36 @@ def operator_unitary_test( if the adjointness property does not hold """ torch.testing.assert_close(u, operator.adjoint(operator(u)[0])[0], rtol=relative_tolerance, atol=absolute_tolerance) + + +Tin = TypeVarTuple('Tin') + + +def autodiff_test( + operator: Operator[Unpack[Tin], tuple[torch.Tensor, ...]], + *u: Unpack[Tin], +): + """Test if autodiff of an operator is working. + This test does not check that the gradient is correct but simply that it can be calculated using both torch.func.jvp + and torch.func.vjp. + + Parameters + ---------- + operator + operator + u + element(s) of the domain of the operator + + Raises + ------ + AssertionError + if autodiff fails + """ + # Forward-mode autodiff using jvp + with torch.autograd.detect_anomaly(): + v_range, _ = torch.func.jvp(operator.forward, u, u) + + # Backward-mode autodiff using vjp + with torch.autograd.detect_anomaly(): + (_, vjpfunc) = torch.func.vjp(operator.forward, *u) + vjpfunc(v_range) diff --git a/tests/operators/models/test_inversion_recovery.py b/tests/operators/models/test_inversion_recovery.py index 52f957f8f..b3d32a211 100644 --- a/tests/operators/models/test_inversion_recovery.py +++ b/tests/operators/models/test_inversion_recovery.py @@ -3,6 +3,7 @@ import pytest import torch from mrpro.operators.models import InversionRecovery +from tests import autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples @@ -39,3 +40,10 @@ def test_inversion_recovery_shape(parameter_shape, contrast_dim_shape, signal_sh m0, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) (signal,) = model_op(m0, t1) assert signal.shape == signal_shape + + +def test_autodiff_inversion_recovery(): + """Test autodiff works for inversion_recovery model.""" + model = InversionRecovery(ti=10) + m0, t1 = create_parameter_tensor_tuples(parameter_shape=(2, 5, 10, 10, 10), number_of_tensors=2) + autodiff_test(model, m0, t1) diff --git a/tests/operators/models/test_molli.py b/tests/operators/models/test_molli.py index c92f92556..82b5f6c04 100644 --- a/tests/operators/models/test_molli.py +++ b/tests/operators/models/test_molli.py @@ -3,7 +3,7 @@ import pytest import torch from mrpro.operators.models import MOLLI -from tests import RandomGenerator +from tests import RandomGenerator, autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples @@ -45,3 +45,10 @@ def test_molli_shape(parameter_shape, contrast_dim_shape, signal_shape): a, c, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) (signal,) = model_op(a, c, t1) assert signal.shape == signal_shape + + +def test_autodiff_molli(): + """Test autodiff works for molli model.""" + model = MOLLI(ti=10) + a, b, t1 = create_parameter_tensor_tuples((2, 5, 10, 10, 10), number_of_tensors=3) + autodiff_test(model, a, b, t1) diff --git a/tests/operators/models/test_mono_exponential_decay.py b/tests/operators/models/test_mono_exponential_decay.py index d77d5862c..1aba27891 100644 --- a/tests/operators/models/test_mono_exponential_decay.py +++ b/tests/operators/models/test_mono_exponential_decay.py @@ -3,6 +3,7 @@ import pytest import torch from mrpro.operators.models import MonoExponentialDecay +from tests import autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples @@ -41,3 +42,10 @@ def test_mono_exponential_decay_shape(parameter_shape, contrast_dim_shape, signa m0, decay_constant = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) (signal,) = model_op(m0, decay_constant) assert signal.shape == signal_shape + + +def test_autodiff_exponential_decay(): + """Test autodiff works for mono-exponential decay model.""" + model = MonoExponentialDecay(decay_time=20) + m0, decay_constant = create_parameter_tensor_tuples(parameter_shape=(2, 5, 10, 10, 10), number_of_tensors=2) + autodiff_test(model, m0, decay_constant) diff --git a/tests/operators/models/test_saturation_recovery.py b/tests/operators/models/test_saturation_recovery.py index 692d4cc31..0b2406ee6 100644 --- a/tests/operators/models/test_saturation_recovery.py +++ b/tests/operators/models/test_saturation_recovery.py @@ -3,6 +3,7 @@ import pytest import torch from mrpro.operators.models import SaturationRecovery +from tests import autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples @@ -41,3 +42,10 @@ def test_saturation_recovery_shape(parameter_shape, contrast_dim_shape, signal_s m0, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=2) (signal,) = model_op(m0, t1) assert signal.shape == signal_shape + + +def test_autodiff_aturation_recovery(): + """Test autodiff works for aturation recovery model.""" + model = SaturationRecovery(ti=10) + m0, t1 = create_parameter_tensor_tuples((2, 5, 10, 10, 10), number_of_tensors=2) + autodiff_test(model, m0, t1) diff --git a/tests/operators/models/test_transient_steady_state_with_preparation.py b/tests/operators/models/test_transient_steady_state_with_preparation.py index 553a8dbcc..5d43f8eb8 100644 --- a/tests/operators/models/test_transient_steady_state_with_preparation.py +++ b/tests/operators/models/test_transient_steady_state_with_preparation.py @@ -4,6 +4,7 @@ import torch from einops import repeat from mrpro.operators.models import TransientSteadyStateWithPreparation +from tests import autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples @@ -79,3 +80,17 @@ def test_transient_steady_state_shape(parameter_shape, contrast_dim_shape, signa m0, t1, flip_angle = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) (signal,) = model_op(m0, t1, flip_angle) assert signal.shape == signal_shape + + +def test_autodiff_transient_steady_state(): + """Test autodiff works for transient steady state model.""" + contrast_dim_shape = (6,) + (sampling_time,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1) + repetition_time, m0_scaling_preparation, delay_after_preparation = create_parameter_tensor_tuples( + contrast_dim_shape[1:], number_of_tensors=3 + ) + model = TransientSteadyStateWithPreparation( + sampling_time, repetition_time, m0_scaling_preparation, delay_after_preparation + ) + m0, t1, flip_angle = create_parameter_tensor_tuples(parameter_shape=(2, 5, 10, 10, 10), number_of_tensors=3) + autodiff_test(model, m0, t1, flip_angle) diff --git a/tests/operators/models/test_wasabi.py b/tests/operators/models/test_wasabi.py index 3d5c9e312..3e58e0a05 100644 --- a/tests/operators/models/test_wasabi.py +++ b/tests/operators/models/test_wasabi.py @@ -2,6 +2,7 @@ import torch from mrpro.operators.models import WASABI +from tests import autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples @@ -45,3 +46,10 @@ def test_WASABI_shape(parameter_shape, contrast_dim_shape, signal_shape): b0_shift, rb1, c, d = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=4) (signal,) = model_op(b0_shift, rb1, c, d) assert signal.shape == signal_shape + + +def test_autodiff_WASABI(): + """Test autodiff works for WASABI model.""" + offset, b0_shift, rb1, c, d = create_data(offset_max=300, n_offsets=2) + wasabi_model = WASABI(offsets=offset) + autodiff_test(wasabi_model, b0_shift, rb1, c, d) diff --git a/tests/operators/models/test_wasabiti.py b/tests/operators/models/test_wasabiti.py index eeedf9463..637f9ff9e 100644 --- a/tests/operators/models/test_wasabiti.py +++ b/tests/operators/models/test_wasabiti.py @@ -3,6 +3,7 @@ import pytest import torch from mrpro.operators.models import WASABITI +from tests import autodiff_test from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples @@ -76,3 +77,11 @@ def test_WASABITI_shape(parameter_shape, contrast_dim_shape, signal_shape): b0_shift, rb1, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) (signal,) = model_op(b0_shift, rb1, t1) assert signal.shape == signal_shape + + +def test_autodiff_WASABITI(): + """Test autodiff works for WASABITI model.""" + offset, b0_shift, rb1, t1 = create_data(offset_max=300, n_offsets=2) + trec = torch.ones_like(offset) * t1 + wasabiti_model = WASABITI(offsets=offset, trec=trec) + autodiff_test(wasabiti_model, b0_shift, rb1, t1) diff --git a/tests/operators/test_autograd_linop.py b/tests/operators/test_autograd_linop.py index f2df2d1af..da833de98 100644 --- a/tests/operators/test_autograd_linop.py +++ b/tests/operators/test_autograd_linop.py @@ -5,8 +5,7 @@ from mrpro.operators import LinearOperator from torch.autograd.gradcheck import GradcheckError -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test class NonDifferentiableOperator(LinearOperator, adjoint_as_backward=False): diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 5bff62dab..49959f911 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -6,9 +6,8 @@ from mrpro.data import KTrajectory, SpatialDimension from mrpro.operators import CartesianSamplingOp -from tests import RandomGenerator +from tests import RandomGenerator, dotproduct_adjointness_test from tests.conftest import create_traj -from tests.helper import dotproduct_adjointness_test def test_cart_sampling_op_data_match(): diff --git a/tests/operators/test_constraints_op.py b/tests/operators/test_constraints_op.py index 5d0ba55b0..b11f8f6d2 100644 --- a/tests/operators/test_constraints_op.py +++ b/tests/operators/test_constraints_op.py @@ -4,7 +4,7 @@ import torch from mrpro.operators import ConstraintsOp -from tests import RandomGenerator +from tests import RandomGenerator, autodiff_test @pytest.mark.parametrize( @@ -141,3 +141,15 @@ def test_constraints_operator_multiple_inputs(bounds): def test_constraints_operator_illegal_bounds(bounds): with pytest.raises(ValueError, match='invalid'): ConstraintsOp(bounds) + + +def test_autodiff_constraints_operator(): + """Test autodiff works for constraints operator.""" + # random tensors with arbitrary values + random_generator = RandomGenerator(seed=0) + x1 = random_generator.float32_tensor(size=(36, 72), low=-1, high=1) + x2 = random_generator.float32_tensor(size=(36, 72), low=-1, high=1) + x3 = random_generator.float32_tensor(size=(36, 72), low=-1, high=1) + + constraints_op = ConstraintsOp(bounds=((None, None), (1.0, None), (None, 1.0))) + autodiff_test(constraints_op, x1, x2, x3) diff --git a/tests/operators/test_density_compensation_op.py b/tests/operators/test_density_compensation_op.py index 616d0e8f9..96b59547e 100644 --- a/tests/operators/test_density_compensation_op.py +++ b/tests/operators/test_density_compensation_op.py @@ -4,8 +4,7 @@ from mrpro.data import DcfData from mrpro.operators import DensityCompensationOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test def test_density_compensation_op_adjointness(): diff --git a/tests/operators/test_einsum_op.py b/tests/operators/test_einsum_op.py index 1db7c0ac3..8e098ab32 100644 --- a/tests/operators/test_einsum_op.py +++ b/tests/operators/test_einsum_op.py @@ -4,8 +4,7 @@ import torch from mrpro.operators.EinsumOp import EinsumOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test @pytest.mark.parametrize('dtype', ['float32', 'complex128']) diff --git a/tests/operators/test_fast_fourier_op.py b/tests/operators/test_fast_fourier_op.py index 7e2f47fd7..b7fd94576 100644 --- a/tests/operators/test_fast_fourier_op.py +++ b/tests/operators/test_fast_fourier_op.py @@ -6,8 +6,7 @@ from mrpro.data import SpatialDimension from mrpro.operators import FastFourierOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test @pytest.mark.parametrize(('npoints', 'a'), [(100, 20), (300, 20)]) diff --git a/tests/operators/test_finite_difference_op.py b/tests/operators/test_finite_difference_op.py index f79da441b..ea21ae919 100644 --- a/tests/operators/test_finite_difference_op.py +++ b/tests/operators/test_finite_difference_op.py @@ -5,8 +5,7 @@ from einops import repeat from mrpro.operators import FiniteDifferenceOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test @pytest.mark.parametrize('mode', ['central', 'forward', 'backward']) diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index 5eccbbc1b..826ea24f6 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -6,9 +6,8 @@ from mrpro.data.traj_calculators import KTrajectoryCartesian from mrpro.operators import FourierOp -from tests import RandomGenerator +from tests import RandomGenerator, dotproduct_adjointness_test from tests.conftest import COMMON_MR_TRAJECTORIES, create_traj -from tests.helper import dotproduct_adjointness_test def create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): diff --git a/tests/operators/test_grid_sampling_op.py b/tests/operators/test_grid_sampling_op.py index ed020956a..c8d2ccaf0 100644 --- a/tests/operators/test_grid_sampling_op.py +++ b/tests/operators/test_grid_sampling_op.py @@ -8,8 +8,7 @@ from mrpro.operators import GridSamplingOp from torch.autograd.gradcheck import gradcheck -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test @pytest.mark.parametrize('dtype', ['float32', 'float64', 'complex64']) diff --git a/tests/operators/test_linearoperatormatrix.py b/tests/operators/test_linearoperatormatrix.py index 7ba87d715..2a669d799 100644 --- a/tests/operators/test_linearoperatormatrix.py +++ b/tests/operators/test_linearoperatormatrix.py @@ -5,8 +5,7 @@ from mrpro.operators import EinsumOp, LinearOperator, MagnitudeOp from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test def random_linearop(size, rng): diff --git a/tests/operators/test_magnitude_op.py b/tests/operators/test_magnitude_op.py index 88fc28209..d4cab4974 100644 --- a/tests/operators/test_magnitude_op.py +++ b/tests/operators/test_magnitude_op.py @@ -3,15 +3,23 @@ import torch from mrpro.operators import MagnitudeOp -from tests import RandomGenerator +from tests import RandomGenerator, autodiff_test def test_magnitude_operator_forward(): """Test that MagnitudeOp returns abs of tensors.""" - rng = RandomGenerator(2) - a = rng.complex64_tensor((2, 3)) - b = rng.complex64_tensor((3, 10)) + random_generator = RandomGenerator(seed=2) + a = random_generator.complex64_tensor((2, 3)) + b = random_generator.complex64_tensor((3, 10)) magnitude_op = MagnitudeOp() magnitude_a, magnitude_b = magnitude_op(a, b) assert torch.allclose(magnitude_a, torch.abs(a)) assert torch.allclose(magnitude_b, torch.abs(b)) + + +def test_autodiff_magnitude_operator(): + """Test autodiff works for magnitude operator.""" + random_generator = RandomGenerator(seed=2) + a = random_generator.complex64_tensor((5, 9, 8)) + b = random_generator.complex64_tensor((10, 11, 12)) + autodiff_test(MagnitudeOp(), a, b) diff --git a/tests/operators/test_operators.py b/tests/operators/test_operators.py index c658e769d..378060b74 100644 --- a/tests/operators/test_operators.py +++ b/tests/operators/test_operators.py @@ -7,8 +7,7 @@ from mrpro.operators import LinearOperator, Operator from typing_extensions import Any, assert_type -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test class DummyOperator(Operator[torch.Tensor, tuple[torch.Tensor,]]): diff --git a/tests/operators/test_pca_compression_op.py b/tests/operators/test_pca_compression_op.py index a600bcecb..e73bf3951 100644 --- a/tests/operators/test_pca_compression_op.py +++ b/tests/operators/test_pca_compression_op.py @@ -3,8 +3,7 @@ import pytest from mrpro.operators import PCACompressionOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test @pytest.mark.parametrize( diff --git a/tests/operators/test_phase_op.py b/tests/operators/test_phase_op.py index aadfbdf24..726569312 100644 --- a/tests/operators/test_phase_op.py +++ b/tests/operators/test_phase_op.py @@ -3,15 +3,23 @@ import torch from mrpro.operators import PhaseOp -from tests import RandomGenerator +from tests import RandomGenerator, autodiff_test def test_phase_operator_forward(): """Test that PhaseOp returns angle of tensors.""" - rng = RandomGenerator(2) - a = rng.complex64_tensor((2, 3)) - b = rng.complex64_tensor((3, 10)) + random_generator = RandomGenerator(seed=2) + a = random_generator.complex64_tensor((2, 3)) + b = random_generator.complex64_tensor((3, 10)) phase_op = PhaseOp() phase_a, phase_b = phase_op(a, b) assert torch.allclose(phase_a, torch.angle(a)) assert torch.allclose(phase_b, torch.angle(b)) + + +def test_autodiff_magnitude_operator(): + """Test autodiff works for magnitude operator.""" + random_generator = RandomGenerator(seed=2) + a = random_generator.complex64_tensor((5, 9, 8)) + b = random_generator.complex64_tensor((10, 11, 12)) + autodiff_test(PhaseOp(), a, b) diff --git a/tests/operators/test_rearrangeop.py b/tests/operators/test_rearrangeop.py index 054c99402..ecacafb42 100644 --- a/tests/operators/test_rearrangeop.py +++ b/tests/operators/test_rearrangeop.py @@ -3,8 +3,7 @@ import pytest from mrpro.operators.RearrangeOp import RearrangeOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test @pytest.mark.parametrize('dtype', ['float32', 'complex128']) diff --git a/tests/operators/test_sensitivity_op.py b/tests/operators/test_sensitivity_op.py index 1321498d8..5576a9892 100644 --- a/tests/operators/test_sensitivity_op.py +++ b/tests/operators/test_sensitivity_op.py @@ -5,8 +5,7 @@ from mrpro.data import CsmData, QHeader, SpatialDimension from mrpro.operators import SensitivityOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test def test_sensitivity_op_adjointness(): diff --git a/tests/operators/test_slice_projection_op.py b/tests/operators/test_slice_projection_op.py index 09a296d4c..53209e98d 100644 --- a/tests/operators/test_slice_projection_op.py +++ b/tests/operators/test_slice_projection_op.py @@ -9,8 +9,7 @@ from mrpro.operators import SliceProjectionOp from mrpro.utils.slice_profiles import SliceGaussian, SliceInterpolate, SliceSmoothedRectangular -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test def test_slice_projection_op_cube_basic(): diff --git a/tests/operators/test_wavelet_op.py b/tests/operators/test_wavelet_op.py index 46e90da46..92de01286 100644 --- a/tests/operators/test_wavelet_op.py +++ b/tests/operators/test_wavelet_op.py @@ -8,8 +8,7 @@ from ptwt.conv_transform_2 import wavedec2 from ptwt.conv_transform_3 import wavedec3 -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test, operator_isometry_test, operator_unitary_test +from tests import RandomGenerator, dotproduct_adjointness_test, linear_operator_unitary_test, operator_isometry_test @pytest.mark.parametrize( @@ -168,4 +167,4 @@ def test_wavelet_op_unitary(im_shape, domain_shape, dim, wavelet_name): random_generator = RandomGenerator(seed=0) img = random_generator.complex64_tensor(size=im_shape) wavelet_op = WaveletOp(domain_shape=domain_shape, dim=dim, wavelet_name=wavelet_name) - operator_unitary_test(wavelet_op, img) + linear_operator_unitary_test(wavelet_op, img) diff --git a/tests/operators/test_zero_op.py b/tests/operators/test_zero_op.py index b47a82998..1e7f47017 100644 --- a/tests/operators/test_zero_op.py +++ b/tests/operators/test_zero_op.py @@ -3,8 +3,7 @@ from mrpro.operators.LinearOperator import LinearOperatorSum from typing_extensions import assert_type -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test def test_zero_op_keepshape(): diff --git a/tests/operators/test_zero_pad_op.py b/tests/operators/test_zero_pad_op.py index ce2d8855d..5d1f02135 100644 --- a/tests/operators/test_zero_pad_op.py +++ b/tests/operators/test_zero_pad_op.py @@ -4,8 +4,7 @@ import torch from mrpro.operators import ZeroPadOp -from tests import RandomGenerator -from tests.helper import dotproduct_adjointness_test +from tests import RandomGenerator, dotproduct_adjointness_test def test_zero_pad_op_content(): diff --git a/tests/phantoms/test_ellipse_phantom.py b/tests/phantoms/test_ellipse_phantom.py index 2b34b06b5..c7ab58850 100644 --- a/tests/phantoms/test_ellipse_phantom.py +++ b/tests/phantoms/test_ellipse_phantom.py @@ -5,7 +5,7 @@ from mrpro.data import SpatialDimension from mrpro.operators import FastFourierOp -from tests.helper import relative_image_difference +from tests import relative_image_difference def test_image_space(ellipse_phantom): From 96f66aaeaa0e2b52eafa786ce5b7779b04dad437 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Thu, 14 Nov 2024 13:35:49 +0100 Subject: [PATCH 22/35] Implement Gram Shortcut for FourierOp (#503) --- src/mrpro/operators/CartesianSamplingOp.py | 66 +++++++ src/mrpro/operators/FourierOp.py | 161 ++++++++++++++++++ src/mrpro/utils/zero_pad_or_crop.py | 2 +- tests/operators/test_cartesian_sampling_op.py | 99 ++++++++--- tests/operators/test_fourier_op.py | 19 +++ 5 files changed, 319 insertions(+), 28 deletions(-) diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index 07f8aba65..47c71c77f 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -212,3 +212,69 @@ def _broadcast_and_scatter_along_last_dim( ).scatter_(dim=-1, index=idx_expanded, src=data_to_scatter) return data_scattered + + @property + def gram(self) -> 'CartesianSamplingGramOp': + """Return the Gram operator for this Cartesian Sampling Operator. + + Returns + ------- + Gram operator for this Cartesian Sampling Operator + """ + return CartesianSamplingGramOp(self) + + +class CartesianSamplingGramOp(LinearOperator): + """Gram operator for Cartesian Sampling Operator. + + The Gram operator is the composition CartesianSamplingOp.H @ CartesianSamplingOp. + """ + + def __init__(self, sampling_op: CartesianSamplingOp): + """Initialize Cartesian Sampling Gram Operator class. + + This should not be used directly, but rather through the `gram` method of a + :class:`mrpro.operator.CartesianSamplingOp` object. + + Parameters + ---------- + sampling_op + The Cartesian Sampling Operator for which to create the Gram operator. + """ + super().__init__() + if sampling_op._needs_indexing: + ones = torch.ones(*sampling_op._trajectory_shape[:-3], 1, *sampling_op._sorted_grid_shape.zyx) + (mask,) = sampling_op.adjoint(*sampling_op.forward(ones)) + self.register_buffer('_mask', mask) + else: + self._mask: torch.Tensor | None = None + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the Gram operator. + + Parameters + ---------- + x + Input data + + Returns + ------- + Output data + """ + if self._mask is None: + return (x,) + return (x * self._mask,) + + def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint of the Gram operator. + + Parameters + ---------- + y + Input data + + Returns + ------- + Output data + """ + return self.forward(y) diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index cacdda1dc..a3e81aba7 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -1,6 +1,7 @@ """Fourier Operator.""" from collections.abc import Sequence +from itertools import product import numpy as np import torch @@ -223,3 +224,163 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: x = x.permute(*unpermute) return (x,) + + @property + def gram(self) -> LinearOperator: + """Return the gram operator.""" + return FourierGramOp(self) + + +def symmetrize(kernel: torch.Tensor, rank: int) -> torch.Tensor: + """Enforce hermitian symmetry on the kernel. Returns only half of the kernel.""" + flipped = kernel.clone() + for d in range(-rank, 0): + flipped = flipped.index_select(d, -1 * torch.arange(flipped.shape[d], device=flipped.device) % flipped.size(d)) + kernel = (kernel + flipped.conj()) / 2 + last_len = kernel.shape[-1] + return kernel[..., : last_len // 2 + 1] + + +def gram_nufft_kernel(weight: torch.Tensor, trajectory: torch.Tensor, recon_shape: Sequence[int]) -> torch.Tensor: + """Calculate the convolution kernel for the NUFFT gram operator. + + Parameters + ---------- + weight + either ones or density compensation weights + trajectory + k-space trajectory + recon_shape + shape of the reconstructed image + + Returns + ------- + kernel + real valued convolution kernel for the NUFFT gram operator, already in Fourier space + """ + rank = trajectory.shape[-2] + if rank != len(recon_shape): + raise ValueError('Rank of trajectory and image size must match.') + # Instead of doing one adjoint nufft with double the recon size in all dimensions, + # we do two adjoint nuffts per dimensions, saving a lot of memory. + adjnufft_ob = KbNufftAdjoint(im_size=recon_shape, n_shift=[0] * rank).to(trajectory) + + kernel = adjnufft_ob(weight, trajectory) # this will be the top left ... corner block + pad = [] + for s in kernel.shape[: -rank - 1 : -1]: + pad.extend([0, s]) + kernel = torch.nn.functional.pad(kernel, pad) # twice the size in all dimensions + + for flips in list(product([1, -1], repeat=rank)): + if all(flip == 1 for flip in flips): + # top left ... block already processed before padding + continue + flipped_trajectory = trajectory * torch.tensor(flips).to(trajectory).unsqueeze(-1) + kernel_part = adjnufft_ob(weight, flipped_trajectory) + slices = [] # which part of the kernel to is currently being processed + for dim, flip in zip(range(-rank, 0), flips, strict=True): + if flip > 0: # first half in the dimension + slices.append(slice(0, kernel_part.size(dim))) + else: # second half in the dimension + slices.append(slice(kernel_part.size(dim) + 1, None)) + kernel_part = kernel_part.index_select(dim, torch.arange(kernel_part.size(dim) - 1, 0, -1)) # flip + + kernel[[..., *slices]] = kernel_part + + kernel = symmetrize(kernel, rank) + kernel = torch.fft.hfftn(kernel, dim=list(range(-rank, 0)), norm='backward') + kernel /= kernel.shape[-rank:].numel() + kernel = torch.fft.fftshift(kernel, dim=list(range(-rank, 0))) + return kernel + + +class FourierGramOp(LinearOperator): + """Gram operator for the Fourier operator. + + Implements the adjoint of the forward operator of the Fourier operator, i.e. the gram operator + `F.H@F. + + Uses a convolution, implemented as multiplication in Fourier space, to calculate the gram operator + for the toeplitz NUFFT operator. + + Uses a multiplication with a binary mask in Fourier space to calculate the gram operator for + the Cartesian FFT operator + + This Operator is only used internally and should not be used directly. + Instead, consider using the `gram` property of :class: `mrpro.operators.FourierOp`. + """ + + _kernel: torch.Tensor | None + + def __init__(self, fourier_op: FourierOp) -> None: + """Initialize the gram operator. + + If density compensation weights are provided, they the operator + F.H@dcf@F is calculated. + + Parameters + ---------- + fourier_op + the Fourier operator to calculate the gram operator for + + """ + super().__init__() + if fourier_op._nufft_dims and fourier_op._omega is not None: + weight = torch.ones_like(fourier_op._omega[..., :1, :, :, :]) + keep_dims = [-4, *fourier_op._nufft_dims] # -4 is coil + permute = [i for i in range(-weight.ndim, 0) if i not in keep_dims] + keep_dims + unpermute = np.argsort(permute) + weight = weight.permute(*permute) + weight_unflattend_shape = weight.shape + weight = weight.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1) + weight = weight + 0j + omega = fourier_op._omega.permute(*permute) + omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1) + kernel = gram_nufft_kernel(weight, omega, fourier_op._nufft_im_size) + kernel = kernel.reshape(*weight_unflattend_shape[: -len(keep_dims)], *kernel.shape[-len(keep_dims) :]) + kernel = kernel.permute(*unpermute) + fft = FastFourierOp( + dim=fourier_op._nufft_dims, + encoding_matrix=[2 * s for s in fourier_op._nufft_im_size], + recon_matrix=fourier_op._nufft_im_size, + ) + self.nufft_gram: None | LinearOperator = fft.H * kernel @ fft + else: + self.nufft_gram = None + + if fourier_op._fast_fourier_op is not None and fourier_op._cart_sampling_op is not None: + self.fast_fourier_gram: None | LinearOperator = ( + fourier_op._fast_fourier_op.H @ fourier_op._cart_sampling_op.gram @ fourier_op._fast_fourier_op + ) + else: + self.fast_fourier_gram = None + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the operator to the input tensor. + + Parameters + ---------- + x + input tensor, shape (..., coils, z, y, x) + """ + if self.nufft_gram is not None: + (x,) = self.nufft_gram(x) + + if self.fast_fourier_gram is not None: + (x,) = self.fast_fourier_gram(x) + return (x,) + + def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint operator to the input tensor. + + Parameters + ---------- + x + input tensor, shape (..., coils, k2, k1, k0) + """ + return self.forward(x) + + @property + def H(self) -> Self: # noqa: N802 + """Adjoint operator of the gram operator.""" + return self diff --git a/src/mrpro/utils/zero_pad_or_crop.py b/src/mrpro/utils/zero_pad_or_crop.py index 42adda430..23fb39599 100644 --- a/src/mrpro/utils/zero_pad_or_crop.py +++ b/src/mrpro/utils/zero_pad_or_crop.py @@ -35,7 +35,7 @@ def zero_pad_or_crop( new_shape: Sequence[int] | torch.Size, dim: None | Sequence[int] = None, ) -> torch.Tensor: - """Change shape of data by cropping or zero-padding. + """Change shape of data by center cropping or symmetric zero-padding. Parameters ---------- diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 49959f911..0fd320212 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -5,6 +5,7 @@ from einops import rearrange from mrpro.data import KTrajectory, SpatialDimension from mrpro.operators import CartesianSamplingOp +from typing_extensions import Unpack from tests import RandomGenerator, dotproduct_adjointness_test from tests.conftest import create_traj @@ -50,33 +51,11 @@ def test_cart_sampling_op_data_match(): torch.testing.assert_close(kdata[:, :, ::2, ::4, ::3], k_sub[:, :, ::2, ::4, ::3]) -@pytest.mark.parametrize( - 'sampling', - [ - 'random', - 'partial_echo', - 'partial_fourier', - 'regular_undersampling', - 'random_undersampling', - 'different_random_undersampling', - 'cartesian_and_non_cartesian', - 'kx_ky_along_k0', - 'kx_ky_along_k0_undersampling', - ], -) -def test_cart_sampling_op_fwd_adj(sampling): - """Test adjoint property of Cartesian sampling operator.""" - - # Create 3D uniform trajectory - k_shape = (2, 5, 20, 40, 60) - nkx = (2, 1, 1, 60) - nky = (2, 1, 40, 1) - nkz = (2, 20, 1, 1) - type_kx = 'uniform' - type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' - type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' - trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz).as_tensor() - +def subsample_traj( + trajectory: KTrajectory, sampling: str, k_shape: tuple[int, int, int, Unpack[tuple[int, ...]]] +) -> KTrajectory: + """Subsample trajectory based on sampling type.""" + trajectory_tensor = trajectory.as_tensor() # Subsample data and trajectory match sampling: case 'random': @@ -108,6 +87,36 @@ def test_cart_sampling_op_fwd_adj(sampling): trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: trajectory_tensor.shape[-1] // 2]]) case _: raise NotImplementedError(f'Test {sampling} not implemented.') + return trajectory + + +@pytest.mark.parametrize( + 'sampling', + [ + 'random', + 'partial_echo', + 'partial_fourier', + 'regular_undersampling', + 'random_undersampling', + 'different_random_undersampling', + 'cartesian_and_non_cartesian', + 'kx_ky_along_k0', + 'kx_ky_along_k0_undersampling', + ], +) +def test_cart_sampling_op_fwd_adj(sampling): + """Test adjoint property of Cartesian sampling operator.""" + + # Create 3D uniform trajectory + k_shape = (2, 5, 20, 40, 60) + nkx = (2, 1, 1, 60) + nky = (2, 1, 40, 1) + nkz = (2, 20, 1, 1) + type_kx = 'uniform' + type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) + trajectory = subsample_traj(trajectory, sampling, k_shape) encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory) @@ -119,6 +128,42 @@ def test_cart_sampling_op_fwd_adj(sampling): dotproduct_adjointness_test(sampling_op, u, v) +@pytest.mark.parametrize( + 'sampling', + [ + 'random', + 'partial_echo', + 'partial_fourier', + 'regular_undersampling', + 'random_undersampling', + 'different_random_undersampling', + 'cartesian_and_non_cartesian', + 'kx_ky_along_k0', + 'kx_ky_along_k0_undersampling', + ], +) +def test_cart_sampling_op_gram(sampling): + """Test adjoint gram of Cartesian sampling operator.""" + + # Create 3D uniform trajectory + k_shape = (2, 5, 20, 40, 60) + nkx = (2, 1, 1, 60) + nky = (2, 1, 40, 1) + nkz = (2, 20, 1, 1) + type_kx = 'uniform' + type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) + trajectory = subsample_traj(trajectory, sampling, k_shape) + + encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory) + u = RandomGenerator(seed=0).complex64_tensor(size=k_shape) + (expected,) = (sampling_op.H @ sampling_op)(u) + (actual,) = sampling_op.gram(u) + torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3) + + @pytest.mark.parametrize(('k2_min', 'k2_max'), [(-1, 21), (-21, 1)]) @pytest.mark.parametrize(('k0_min', 'k0_max'), [(-6, 13), (-13, 6)]) def test_cart_sampling_op_oversampling(k0_min, k0_max, k2_min, k2_max): diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index 826ea24f6..c7c58c266 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -48,6 +48,25 @@ def test_fourier_op_fwd_adj_property( dotproduct_adjointness_test(fourier_op, u, v) +@COMMON_MR_TRAJECTORIES +def test_fourier_op_gram(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2): + """Test gram of Fourier operator.""" + img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) + + recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) + fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) + + (expected,) = (fourier_op.H @ fourier_op)(img) + (actual,) = fourier_op.gram(img) + + torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3) + + @pytest.mark.parametrize( ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz'), # parameter names [ From d9e649adf031dab5d101513569bd35fce2a1fbdb Mon Sep 17 00:00:00 2001 From: rkcatarina <79782399+rkcatarina@users.noreply.github.com> Date: Thu, 14 Nov 2024 14:03:15 +0100 Subject: [PATCH 23/35] Remove MSEDataConsistency (#450) Co-authored-by: Catarina Redshaw Kranich Co-authored-by: Christoph Kolbitsch --- README.md | 2 +- examples/qmri_sg_challenge_2024_t1.ipynb | 4 +- examples/qmri_sg_challenge_2024_t1.py | 4 +- examples/qmri_sg_challenge_2024_t2_star.ipynb | 4 +- examples/qmri_sg_challenge_2024_t2_star.py | 4 +- examples/t1_mapping_with_grad_acq.ipynb | 6 +-- examples/t1_mapping_with_grad_acq.py | 6 +-- pyproject.toml | 1 + src/mrpro/operators/functionals/L1Norm.py | 2 + .../operators/functionals/L1NormViewAsReal.py | 2 + .../operators/functionals/L2NormSquared.py | 2 + src/mrpro/operators/functionals/MSE.py | 47 ++++++++++++++++ .../functionals/MSEDataDiscrepancy.py | 54 ------------------- src/mrpro/operators/functionals/__init__.py | 4 +- tests/operators/functionals/__init__.py | 3 +- .../operators/functionals/test_functionals.py | 15 +++++- .../functionals/test_mse_functional.py | 32 ----------- 17 files changed, 87 insertions(+), 105 deletions(-) create mode 100644 src/mrpro/operators/functionals/MSE.py delete mode 100644 src/mrpro/operators/functionals/MSEDataDiscrepancy.py delete mode 100644 tests/operators/functionals/test_mse_functional.py diff --git a/README.md b/README.md index fddeb28eb..422c4ef4d 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Quantitative parameter maps can be obtained by creating a functional to be minim # Define signal model model = MagnitudeOp() @ InversionRecovery(ti=idata_multi_ti.header.ti) # Define loss function and combine with signal model -mse = MSEDataDiscrepancy(idata_multi_ti.data.abs()) +mse = MSE(idata_multi_ti.data.abs()) functional = mse @ model [...] # Run optimization diff --git a/examples/qmri_sg_challenge_2024_t1.ipynb b/examples/qmri_sg_challenge_2024_t1.ipynb index 1832e5b48..98441ca8d 100644 --- a/examples/qmri_sg_challenge_2024_t1.ipynb +++ b/examples/qmri_sg_challenge_2024_t1.ipynb @@ -29,7 +29,7 @@ "from mrpro.algorithms.optimizers import adam\n", "from mrpro.data import IData\n", "from mrpro.operators import MagnitudeOp\n", - "from mrpro.operators.functionals import MSEDataDiscrepancy\n", + "from mrpro.operators.functionals import MSE\n", "from mrpro.operators.models import InversionRecovery" ] }, @@ -150,7 +150,7 @@ "metadata": {}, "outputs": [], "source": [ - "mse = MSEDataDiscrepancy(idata_multi_ti.data.abs())" + "mse = MSE(idata_multi_ti.data.abs())" ] }, { diff --git a/examples/qmri_sg_challenge_2024_t1.py b/examples/qmri_sg_challenge_2024_t1.py index 148f5bdd0..ee146cb20 100644 --- a/examples/qmri_sg_challenge_2024_t1.py +++ b/examples/qmri_sg_challenge_2024_t1.py @@ -16,7 +16,7 @@ from mrpro.algorithms.optimizers import adam from mrpro.data import IData from mrpro.operators import MagnitudeOp -from mrpro.operators.functionals import MSEDataDiscrepancy +from mrpro.operators.functionals import MSE from mrpro.operators.models import InversionRecovery # %% [markdown] @@ -71,7 +71,7 @@ # As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal # model $q$. # %% -mse = MSEDataDiscrepancy(idata_multi_ti.data.abs()) +mse = MSE(idata_multi_ti.data.abs()) # %% [markdown] # Now we can simply combine the two into a functional to solve diff --git a/examples/qmri_sg_challenge_2024_t2_star.ipynb b/examples/qmri_sg_challenge_2024_t2_star.ipynb index e2adcde8c..1fda8b92c 100644 --- a/examples/qmri_sg_challenge_2024_t2_star.ipynb +++ b/examples/qmri_sg_challenge_2024_t2_star.ipynb @@ -28,7 +28,7 @@ "from mpl_toolkits.axes_grid1 import make_axes_locatable # type: ignore [import-untyped]\n", "from mrpro.algorithms.optimizers import adam\n", "from mrpro.data import IData\n", - "from mrpro.operators.functionals import MSEDataDiscrepancy\n", + "from mrpro.operators.functionals import MSE\n", "from mrpro.operators.models import MonoExponentialDecay" ] }, @@ -164,7 +164,7 @@ "metadata": {}, "outputs": [], "source": [ - "mse = MSEDataDiscrepancy(idata_multi_te.data)" + "mse = MSE(idata_multi_te.data)" ] }, { diff --git a/examples/qmri_sg_challenge_2024_t2_star.py b/examples/qmri_sg_challenge_2024_t2_star.py index ced49ae49..e7e28372f 100644 --- a/examples/qmri_sg_challenge_2024_t2_star.py +++ b/examples/qmri_sg_challenge_2024_t2_star.py @@ -15,7 +15,7 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable # type: ignore [import-untyped] from mrpro.algorithms.optimizers import adam from mrpro.data import IData -from mrpro.operators.functionals import MSEDataDiscrepancy +from mrpro.operators.functionals import MSE from mrpro.operators.models import MonoExponentialDecay # %% [markdown] @@ -78,7 +78,7 @@ # As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal # model $q$. # %% -mse = MSEDataDiscrepancy(idata_multi_te.data) +mse = MSE(idata_multi_te.data) # %% [markdown] # Now we can simply combine the two into a functional which will then solve diff --git a/examples/t1_mapping_with_grad_acq.ipynb b/examples/t1_mapping_with_grad_acq.ipynb index 743f7ad8e..d46eddc73 100644 --- a/examples/t1_mapping_with_grad_acq.ipynb +++ b/examples/t1_mapping_with_grad_acq.ipynb @@ -29,7 +29,7 @@ "from mrpro.data import KData\n", "from mrpro.data.traj_calculators import KTrajectoryIsmrmrd\n", "from mrpro.operators import ConstraintsOp, MagnitudeOp\n", - "from mrpro.operators.functionals import MSEDataDiscrepancy\n", + "from mrpro.operators.functionals import MSE\n", "from mrpro.operators.models import TransientSteadyStateWithPreparation\n", "from mrpro.utils import split_idx" ] @@ -317,7 +317,7 @@ }, "source": [ "### Loss function\n", - "As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal\n", + "As a loss function for the optimizer, we calculate the mean squared error between the image data $x$ and our signal\n", "model $q$." ] }, @@ -328,7 +328,7 @@ "metadata": {}, "outputs": [], "source": [ - "mse_loss = MSEDataDiscrepancy(img_rss_dynamic)" + "mse_loss = MSE(img_rss_dynamic)" ] }, { diff --git a/examples/t1_mapping_with_grad_acq.py b/examples/t1_mapping_with_grad_acq.py index 29c08b031..3a355ec11 100644 --- a/examples/t1_mapping_with_grad_acq.py +++ b/examples/t1_mapping_with_grad_acq.py @@ -16,7 +16,7 @@ from mrpro.data import KData from mrpro.data.traj_calculators import KTrajectoryIsmrmrd from mrpro.operators import ConstraintsOp, MagnitudeOp -from mrpro.operators.functionals import MSEDataDiscrepancy +from mrpro.operators.functionals import MSE from mrpro.operators.models import TransientSteadyStateWithPreparation from mrpro.utils import split_idx @@ -173,10 +173,10 @@ # %% [markdown] # ### Loss function -# As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal +# As a loss function for the optimizer, we calculate the mean squared error between the image data $x$ and our signal # model $q$. # %% -mse_loss = MSEDataDiscrepancy(img_rss_dynamic) +mse_loss = MSE(img_rss_dynamic) # %% [markdown] # Now we can simply combine the loss function, the signal model and the constraints to solve diff --git a/pyproject.toml b/pyproject.toml index 3e1b3a78d..a65b1cff3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ authors = [ { name = "Johannes Hammacher", email = "johannnes.hammacher@ptb.de" }, { name = "Stefan Martin", email = "stefan.martin@ptb.de" }, { name = "Andreas Kofler", email = "andreas.kofler@ptb.de" }, + { name = "Catarina Redshaw Kranich", email = "catarina.redshaw-kranich@ptb.de" }, ] classifiers = [ "License :: OSI Approved :: Apache Software License", diff --git a/src/mrpro/operators/functionals/L1Norm.py b/src/mrpro/operators/functionals/L1Norm.py index 20380a9eb..29f7b753c 100644 --- a/src/mrpro/operators/functionals/L1Norm.py +++ b/src/mrpro/operators/functionals/L1Norm.py @@ -13,6 +13,8 @@ class L1Norm(ElementaryProximableFunctional): where W is a either a scalar or tensor that corresponds to a (block-) diagonal operator that is applied to the input. + In most cases, consider setting divide_by_n to true to be independent of input size. + The norm of the vector is computed along the dimensions given at initialization. """ diff --git a/src/mrpro/operators/functionals/L1NormViewAsReal.py b/src/mrpro/operators/functionals/L1NormViewAsReal.py index d8aba9dac..e4227c70b 100644 --- a/src/mrpro/operators/functionals/L1NormViewAsReal.py +++ b/src/mrpro/operators/functionals/L1NormViewAsReal.py @@ -15,6 +15,8 @@ class L1NormViewAsReal(ElementaryProximableFunctional): If the parameter `weight` is real-valued, :math:`W_r` and :math:`W_i` are both set to `weight`. If it is complex-valued, :math:`W_r` and :math:`W_I` are set to the real and imaginary part, respectively. + In most cases, consider setting divide_by_n to true to be independent of input size. + The norm of the vector is computed along the dimensions set at initialization. """ diff --git a/src/mrpro/operators/functionals/L2NormSquared.py b/src/mrpro/operators/functionals/L2NormSquared.py index 275e71bf9..c8d001f97 100644 --- a/src/mrpro/operators/functionals/L2NormSquared.py +++ b/src/mrpro/operators/functionals/L2NormSquared.py @@ -15,6 +15,8 @@ class L2NormSquared(ElementaryProximableFunctional): reconstruction when using a density-compensation function for k-space pre-conditioning, for masking of image data, or for spatially varying regularization weights. + In most cases, consider setting divide_by_n to true to be independent of input size. + Alternatively the functional :class:`mrpro.operators.functionals.MSE` can be used. The norm is computed along the dimensions given at initialization, all other dimensions are considered batch dimensions. """ diff --git a/src/mrpro/operators/functionals/MSE.py b/src/mrpro/operators/functionals/MSE.py new file mode 100644 index 000000000..6a8166863 --- /dev/null +++ b/src/mrpro/operators/functionals/MSE.py @@ -0,0 +1,47 @@ +"""MSE-Functional.""" + +from collections.abc import Sequence + +import torch + +from mrpro.operators.functionals.L2NormSquared import L2NormSquared + + +class MSE(L2NormSquared): + r"""Functional class for the mean squared error.""" + + def __init__( + self, + weight: torch.Tensor | complex = 1.0, + target: torch.Tensor | None | complex = None, + dim: int | Sequence[int] | None = None, + divide_by_n: bool = True, + keepdim: bool = False, + ) -> None: + r"""Initialize MSE Functional. + + The MSE functional is given by + :math:`f: C^N -> [0, \infty), x -> 1/N \| W (x-b)\|_2^2`, + where :math:`W` is either a scalar or tensor that corresponds to a (block-) diagonal operator + that is applied to the input. The division by `N` can be disabled by setting `divide_by_n=False` + For more details also see :class:`mrpro.operators.functionals.L2NormSquared` + + Parameters + ---------- + weight + weight parameter (see above) + target + target element - often data tensor (see above) + dim + dimension(s) over which functional is reduced. + All other dimensions of `weight ( x - target)` will be treated as batch dimensions. + divide_by_n + if true, the result is scaled by the number of elements of the dimensions index by `dim` in + the tensor `weight ( x - target)`. If true, the functional is thus calculated as the mean, + else the sum. + keepdim + if true, the dimension(s) of the input indexed by dim are maintained and collapsed to singeltons, + else they are removed from the result. + + """ + super().__init__(weight=weight, target=target, dim=dim, divide_by_n=divide_by_n, keepdim=keepdim) diff --git a/src/mrpro/operators/functionals/MSEDataDiscrepancy.py b/src/mrpro/operators/functionals/MSEDataDiscrepancy.py deleted file mode 100644 index df44746ec..000000000 --- a/src/mrpro/operators/functionals/MSEDataDiscrepancy.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Mean squared error (MSE) data-discrepancy function.""" - -import torch -import torch.nn.functional as F # noqa: N812 - -from mrpro.operators.Operator import Operator - - -class MSEDataDiscrepancy(Operator[torch.Tensor, tuple[torch.Tensor]]): - """Mean Squared Error (MSE) loss function. - - This class implements the function :math:`1./N * || . - data ||_2^2`, where :math:`N` equals to the number of - elements of the tensor. - - Note: if one of data or input is complex-valued, we identify the space :math:`C^N` with :math:`R^{2N}` and - multiply the output by 2. By this, we achieve that for example :math:`MSE(1)` = :math:`MSE(1+1j*0)` = 1. - - Parameters - ---------- - data - observed data - """ - - def __init__(self, data: torch.Tensor): - """Initialize the MSE data-discrepancy operator.""" - super().__init__() - - # observed data - self.data = data - - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]: - """Calculate the MSE of the input. - - Parameters - ---------- - x - tensor whose MSE with respect to the data given at initialization should be calculated - - Returns - ------- - Mean Squared Error (MSE) of input and the data - """ - if torch.is_complex(x) or torch.is_complex(self.data): - # F.mse_loss is only implemented for real tensors - # Thus, we cast both to C and then to R^2 - # and undo the division by ten twice the number of elements in mse_loss - x_r2 = torch.view_as_real(x) if torch.is_complex(x) else torch.view_as_real(x + 1j * 0) - data_r2 = ( - torch.view_as_real(self.data) if torch.is_complex(self.data) else torch.view_as_real(self.data + 1j * 0) - ) - mse = F.mse_loss(x_r2, data_r2) * 2.0 - else: # both are real - mse = F.mse_loss(x, self.data) - return (mse,) diff --git a/src/mrpro/operators/functionals/__init__.py b/src/mrpro/operators/functionals/__init__.py index 2e44c1c9e..3fe3455d7 100644 --- a/src/mrpro/operators/functionals/__init__.py +++ b/src/mrpro/operators/functionals/__init__.py @@ -1,6 +1,6 @@ from mrpro.operators.functionals.L1Norm import L1Norm from mrpro.operators.functionals.L1NormViewAsReal import L1NormViewAsReal from mrpro.operators.functionals.L2NormSquared import L2NormSquared -from mrpro.operators.functionals.MSEDataDiscrepancy import MSEDataDiscrepancy +from mrpro.operators.functionals.MSE import MSE from mrpro.operators.functionals.ZeroFunctional import ZeroFunctional -__all__ = ["L1Norm", "L1NormViewAsReal", "L2NormSquared", "MSEDataDiscrepancy", "ZeroFunctional"] +__all__ = ["L1Norm", "L1NormViewAsReal", "L2NormSquared", "MSE", "ZeroFunctional"] diff --git a/tests/operators/functionals/__init__.py b/tests/operators/functionals/__init__.py index e9d7aa091..878750b62 100644 --- a/tests/operators/functionals/__init__.py +++ b/tests/operators/functionals/__init__.py @@ -1,4 +1,5 @@ from mrpro.operators.functionals.L1NormViewAsReal import L1NormViewAsReal +from mrpro.operators.functionals.L1Norm import L1Norm from mrpro.operators.functionals.L2NormSquared import L2NormSquared -from mrpro.operators.functionals.MSEDataDiscrepancy import MSEDataDiscrepancy +from mrpro.operators.functionals.MSE import MSE from mrpro.operators.functionals.ZeroFunctional import ZeroFunctional diff --git a/tests/operators/functionals/test_functionals.py b/tests/operators/functionals/test_functionals.py index 7f33b6e81..39b16da12 100644 --- a/tests/operators/functionals/test_functionals.py +++ b/tests/operators/functionals/test_functionals.py @@ -4,7 +4,7 @@ import pytest import torch from mrpro.operators.Functional import ElementaryFunctional, ElementaryProximableFunctional -from mrpro.operators.functionals import L1Norm, L1NormViewAsReal, L2NormSquared, ZeroFunctional +from mrpro.operators.functionals import MSE, L1Norm, L1NormViewAsReal, L2NormSquared, ZeroFunctional from typing_extensions import TypedDict from tests import RandomGenerator @@ -297,6 +297,19 @@ class NumericCase(TypedDict): [[[-2.983529, -1.943529, -1.049412], [-0.108235, 1.468235, 1.971765]]] ), }, + 'MSE': { + # Generated with ODL + 'functional': MSE, + 'x': torch.tensor([[[-3.0, -2.0, -1.0], [0.0, 1.0, 2.0]]]), + 'weight': 2.0, + 'target': torch.tensor([[[0.340, 0.130, 0.230], [0.230, -1.120, -0.190]]]), + 'sigma': 0.5, + 'fx_expected': torch.tensor(17.6992), + 'prox_expected': torch.tensor([[[-1.6640, -1.1480, -0.5080], [0.0920, 0.1520, 1.1240]]]), + 'prox_convex_conj_expected': torch.tensor( + [[[-2.305455, -1.501818, -0.810909], [-0.083636, 1.134545, 1.523636]]] + ), + }, } diff --git a/tests/operators/functionals/test_mse_functional.py b/tests/operators/functionals/test_mse_functional.py deleted file mode 100644 index 3ff9b1ee8..000000000 --- a/tests/operators/functionals/test_mse_functional.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Tests for MSE-functional.""" - -import pytest -import torch -from mrpro.operators.functionals.MSEDataDiscrepancy import MSEDataDiscrepancy - - -@pytest.mark.parametrize( - ('data', 'x', 'expected_mse'), - [ - ((0.0, 0.0), (0.0, 0.0), (0.0)), # zero-tensors deliver 0-error - ((0.0 + 1j * 0, 0.0), (0.0 + 1j * 0, 0.0), (0.0)), # zero-tensors deliver 0-error; complex-valued - ((1.0, 0.0), (1.0, 0.0), (0.0)), # same tensors; both real-valued - ((1.0, 0.0), (1.0 + 1j * 0, 0.0), (0.0)), # same tensors; input complex-valued - ((1.0, 0.0), (1.0 + 1j * 1, 0.0), (0.5)), # different tensors; input complex-valued - ((1.0 + 1j * 0, 0.0), (1.0, 0.0), (0.0)), # same tensors; data complex-valued - ((1.0 + 1j * 1, 0.0), (1.0, 0.0), (0.5)), # different tensors; data complex-valued - ((1.0 + 1j * 0, 0.0), (1.0 + 1j * 0, 0.0), (0.0)), # same tensors; both complex-valued with imag part=0 - ((1.0 + 1j * 1, 0.0), (1.0 + 1j * 1, 0.0), (0.0)), # same tensors; both complex-valued with imag part>0 - ((0.0 + 1j * 1, 0.0), (0.0 + 1j * 1, 0.0), (0.0)), # same tensors; both complex-valued with real part=0 - ], -) -def test_mse_functional(data, x, expected_mse): - """Test if mse_data_discrepancy matches expected values. - - Expected values are supposed to be - 1/N*|| . - data||_2^2 - """ - - mse_op = MSEDataDiscrepancy(torch.tensor(data)) - (mse,) = mse_op(torch.tensor(x)) - torch.testing.assert_close(mse, torch.tensor(expected_mse)) From 199e2e847a6f490a6ffaa41c57072c35f183c9e9 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Thu, 14 Nov 2024 16:01:10 +0100 Subject: [PATCH 24/35] Swap order of target and weight in functionals (#535) --- src/mrpro/operators/Functional.py | 6 +++--- src/mrpro/operators/functionals/MSE.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/mrpro/operators/Functional.py b/src/mrpro/operators/Functional.py index a55603818..6dbce2b6e 100644 --- a/src/mrpro/operators/Functional.py +++ b/src/mrpro/operators/Functional.py @@ -50,8 +50,8 @@ class ElementaryFunctional(Functional): def __init__( self, - weight: torch.Tensor | complex = 1.0, target: torch.Tensor | None | complex = None, + weight: torch.Tensor | complex = 1.0, dim: int | Sequence[int] | None = None, divide_by_n: bool = False, keepdim: bool = False, @@ -64,10 +64,10 @@ def __init__( Parameters ---------- - weight - weight parameter (see above) target target element - often data tensor (see above) + weight + weight parameter (see above) dim dimension(s) over which functional is reduced. All other dimensions of `weight ( x - target)` will be treated as batch dimensions. diff --git a/src/mrpro/operators/functionals/MSE.py b/src/mrpro/operators/functionals/MSE.py index 6a8166863..8b67c4ed7 100644 --- a/src/mrpro/operators/functionals/MSE.py +++ b/src/mrpro/operators/functionals/MSE.py @@ -12,8 +12,8 @@ class MSE(L2NormSquared): def __init__( self, - weight: torch.Tensor | complex = 1.0, target: torch.Tensor | None | complex = None, + weight: torch.Tensor | complex = 1.0, dim: int | Sequence[int] | None = None, divide_by_n: bool = True, keepdim: bool = False, @@ -28,10 +28,10 @@ def __init__( Parameters ---------- - weight - weight parameter (see above) target target element - often data tensor (see above) + weight + weight parameter (see above) dim dimension(s) over which functional is reduced. All other dimensions of `weight ( x - target)` will be treated as batch dimensions. From 38722bf9c9aad87bb30008d6ada488e69f239cd1 Mon Sep 17 00:00:00 2001 From: Lunin Leonid Date: Thu, 14 Nov 2024 18:14:43 +0100 Subject: [PATCH 25/35] Examples typos (#449) Co-authored-by: Felix F Zimmermann --- examples/qmri_sg_challenge_2024_t1.ipynb | 46 ++++----- examples/qmri_sg_challenge_2024_t1.py | 46 ++++----- examples/qmri_sg_challenge_2024_t2_star.ipynb | 28 +++--- examples/qmri_sg_challenge_2024_t2_star.py | 28 +++--- examples/t1_mapping_with_grad_acq.ipynb | 97 +++++++++++++++---- examples/t1_mapping_with_grad_acq.py | 60 ++++++++---- 6 files changed, 194 insertions(+), 111 deletions(-) diff --git a/examples/qmri_sg_challenge_2024_t1.ipynb b/examples/qmri_sg_challenge_2024_t1.ipynb index 98441ca8d..e5605adfb 100644 --- a/examples/qmri_sg_challenge_2024_t1.ipynb +++ b/examples/qmri_sg_challenge_2024_t1.ipynb @@ -5,7 +5,7 @@ "id": "0f82262f", "metadata": {}, "source": [ - "# QMRI Challenge ISMRM 2024 - T1 mapping" + "# QMRI Challenge ISMRM 2024 - $T_1$ mapping" ] }, { @@ -40,7 +40,7 @@ "source": [ "### Overview\n", "The dataset consists of images obtained at 10 different inversion times using a turbo spin echo sequence. Each\n", - "inversion time is saved in a separate DICOM file. In order to obtain a T1 map, we are going to:\n", + "inversion time is saved in a separate DICOM file. In order to obtain a $T_1$ map, we are going to:\n", "- download the data from Zenodo\n", "- read in the DICOM files (one for each inversion time) and combine them in an IData object\n", "- define a signal model and data loss (mean-squared error) function\n", @@ -105,7 +105,7 @@ "fig, axes = plt.subplots(1, 3, squeeze=False)\n", "for idx, ax in enumerate(axes.flatten()):\n", " ax.imshow(torch.abs(idata_multi_ti.data[idx, 0, 0, :, :]))\n", - " ax.set_title(f'TI = {idata_multi_ti.header.ti[idx]:.0f}ms')" + " ax.set_title(f'TI = {idata_multi_ti.header.ti[idx]:.3f}s')" ] }, { @@ -116,9 +116,9 @@ "### Signal model and loss function\n", "We use the model $q$\n", "\n", - "$q(TI) = M_0 (1 - e^{-TI/T1})$\n", + "$q(TI) = M_0 (1 - e^{-TI/T_1})$\n", "\n", - "with the equilibrium magnetization $M_0$, the inversion time $TI$, and $T1$. We have to keep in mind that the DICOM\n", + "with the equilibrium magnetization $M_0$, the inversion time $TI$, and $T_1$. We have to keep in mind that the DICOM\n", "images only contain the magnitude of the signal. Therefore, we need $|q(TI)|$:" ] }, @@ -162,7 +162,7 @@ "source": [ "Now we can simply combine the two into a functional to solve\n", "\n", - "$ \\min_{M_0, T1} || |q(M_0, T1, TI)| - x||_2^2$" + "$ \\min_{M_0, T_1} || |q(M_0, T_1, TI)| - x||_2^2$" ] }, { @@ -187,11 +187,11 @@ "To increase our chances of reaching the global minimum, we can ensure that our starting\n", "values are already close to the global minimum. We need a good starting point for each pixel.\n", "\n", - "One option to get a good starting point is to calculate the signal curves for a range of T1 values and then check\n", + "One option to get a good starting point is to calculate the signal curves for a range of $T_1$ values and then check\n", "for each pixel which of these signal curves fits best. This is similar to what is done for MR Fingerprinting. So we\n", "are going to:\n", - "- define a list of realistic T1 values (we call this a dictionary of T1 values)\n", - "- calculate the signal curves corresponding to each of these T1 values\n", + "- define a list of realistic $T_1$ values (we call this a dictionary of $T_1$ values)\n", + "- calculate the signal curves corresponding to each of these $T_1$ values\n", "- compare the signal curves to the signals of each voxel (we use the maximum of the dot-product as a metric of how\n", "well the signals fit to each other)" ] @@ -203,8 +203,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Define 100 T1 values between 100 and 3000 ms\n", - "t1_dictionary = torch.linspace(100, 3000, 100)\n", + "# Define 100 T1 values between 0.1 and 3.0 s\n", + "t1_dictionary = torch.linspace(0.1, 3.0, 100)\n", "\n", "# Calculate the signal corresponding to each of these T1 values. We set M0 to 1, but this is arbitrary because M0 is\n", "# just a scaling factor and we are going to normalize the signal curves.\n", @@ -227,8 +227,8 @@ "metadata": {}, "outputs": [], "source": [ - "# The image with the longest inversion time is a good approximation of the equilibrium magnetization\n", - "m0_start = torch.abs(idata_multi_ti.data[torch.argmax(idata_multi_ti.header.ti), ...])" + "# The maximum absolute value observed is a good approximation for m0\n", + "m0_start = torch.amax(torch.abs(idata_multi_ti.data), 0)" ] }, { @@ -242,11 +242,11 @@ "fig, axes = plt.subplots(1, 2, figsize=(8, 2), squeeze=False)\n", "colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]]\n", "im = axes[0, 0].imshow(m0_start[0, 0, ...])\n", - "axes[0, 0].set_title('M0 start values')\n", + "axes[0, 0].set_title('$M_0$ start values')\n", "fig.colorbar(im, cax=colorbar_ax[0])\n", - "im = axes[0, 1].imshow(t1_start[0, 0, ...], vmin=0, vmax=2500)\n", - "axes[0, 1].set_title('T1 start values')\n", - "fig.colorbar(im, cax=colorbar_ax[1])" + "im = axes[0, 1].imshow(t1_start[0, 0, ...], vmin=0, vmax=2.5)\n", + "axes[0, 1].set_title('$T_1$ start values')\n", + "fig.colorbar(im, cax=colorbar_ax[1], label='s')" ] }, { @@ -266,7 +266,7 @@ "source": [ "# Hyperparameters for optimizer\n", "max_iter = 2000\n", - "lr = 1e0\n", + "lr = 1e-1\n", "\n", "# Run optimization\n", "params_result = adam(functional, [m0_start, t1_start], max_iter=max_iter, lr=lr)\n", @@ -283,7 +283,7 @@ "### Visualize the final results\n", "To get an impression of how well the fit has worked, we are going to calculate the relative error between\n", "\n", - "$E_{relative} = \\sum_{TI}\\frac{|(q(M_0, T1, TI) - x)|}{|x|}$\n", + "$E_{relative} = \\sum_{TI}\\frac{|(q(M_0, T_1, TI) - x)|}{|x|}$\n", "\n", "on a voxel-by-voxel basis" ] @@ -304,11 +304,11 @@ "fig, axes = plt.subplots(1, 3, figsize=(10, 2), squeeze=False)\n", "colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]]\n", "im = axes[0, 0].imshow(m0[0, 0, ...])\n", - "axes[0, 0].set_title('M0')\n", + "axes[0, 0].set_title('$M_0$')\n", "fig.colorbar(im, cax=colorbar_ax[0])\n", - "im = axes[0, 1].imshow(t1[0, 0, ...], vmin=0, vmax=2500)\n", - "axes[0, 1].set_title('T1')\n", - "fig.colorbar(im, cax=colorbar_ax[1])\n", + "im = axes[0, 1].imshow(t1[0, 0, ...], vmin=0, vmax=2.5)\n", + "axes[0, 1].set_title('$T_1$')\n", + "fig.colorbar(im, cax=colorbar_ax[1], label='s')\n", "im = axes[0, 2].imshow(relative_absolute_error[0, 0, ...], vmin=0, vmax=1.0)\n", "axes[0, 2].set_title('Relative error')\n", "fig.colorbar(im, cax=colorbar_ax[2])" diff --git a/examples/qmri_sg_challenge_2024_t1.py b/examples/qmri_sg_challenge_2024_t1.py index ee146cb20..d0259f267 100644 --- a/examples/qmri_sg_challenge_2024_t1.py +++ b/examples/qmri_sg_challenge_2024_t1.py @@ -1,5 +1,5 @@ # %% [markdown] -# # QMRI Challenge ISMRM 2024 - T1 mapping +# # QMRI Challenge ISMRM 2024 - $T_1$ mapping # %% # Imports @@ -22,7 +22,7 @@ # %% [markdown] # ### Overview # The dataset consists of images obtained at 10 different inversion times using a turbo spin echo sequence. Each -# inversion time is saved in a separate DICOM file. In order to obtain a T1 map, we are going to: +# inversion time is saved in a separate DICOM file. In order to obtain a $T_1$ map, we are going to: # - download the data from Zenodo # - read in the DICOM files (one for each inversion time) and combine them in an IData object # - define a signal model and data loss (mean-squared error) function @@ -53,15 +53,15 @@ fig, axes = plt.subplots(1, 3, squeeze=False) for idx, ax in enumerate(axes.flatten()): ax.imshow(torch.abs(idata_multi_ti.data[idx, 0, 0, :, :])) - ax.set_title(f'TI = {idata_multi_ti.header.ti[idx]:.0f}ms') + ax.set_title(f'TI = {idata_multi_ti.header.ti[idx]:.3f}s') # %% [markdown] # ### Signal model and loss function # We use the model $q$ # -# $q(TI) = M_0 (1 - e^{-TI/T1})$ +# $q(TI) = M_0 (1 - e^{-TI/T_1})$ # -# with the equilibrium magnetization $M_0$, the inversion time $TI$, and $T1$. We have to keep in mind that the DICOM +# with the equilibrium magnetization $M_0$, the inversion time $TI$, and $T_1$. We have to keep in mind that the DICOM # images only contain the magnitude of the signal. Therefore, we need $|q(TI)|$: # %% @@ -76,7 +76,7 @@ # %% [markdown] # Now we can simply combine the two into a functional to solve # -# $ \min_{M_0, T1} || |q(M_0, T1, TI)| - x||_2^2$ +# $ \min_{M_0, T_1} || |q(M_0, T_1, TI)| - x||_2^2$ # %% functional = mse @ model @@ -88,17 +88,17 @@ # To increase our chances of reaching the global minimum, we can ensure that our starting # values are already close to the global minimum. We need a good starting point for each pixel. # -# One option to get a good starting point is to calculate the signal curves for a range of T1 values and then check +# One option to get a good starting point is to calculate the signal curves for a range of $T_1$ values and then check # for each pixel which of these signal curves fits best. This is similar to what is done for MR Fingerprinting. So we # are going to: -# - define a list of realistic T1 values (we call this a dictionary of T1 values) -# - calculate the signal curves corresponding to each of these T1 values +# - define a list of realistic $T_1$ values (we call this a dictionary of $T_1$ values) +# - calculate the signal curves corresponding to each of these $T_1$ values # - compare the signal curves to the signals of each voxel (we use the maximum of the dot-product as a metric of how # well the signals fit to each other) # %% -# Define 100 T1 values between 100 and 3000 ms -t1_dictionary = torch.linspace(100, 3000, 100) +# Define 100 T1 values between 0.1 and 3.0 s +t1_dictionary = torch.linspace(0.1, 3.0, 100) # Calculate the signal corresponding to each of these T1 values. We set M0 to 1, but this is arbitrary because M0 is # just a scaling factor and we are going to normalize the signal curves. @@ -114,19 +114,19 @@ t1_start = rearrange(t1_dictionary[idx_best_match], '(y x)->1 1 y x', y=n_y, x=n_x) # %% -# The image with the longest inversion time is a good approximation of the equilibrium magnetization -m0_start = torch.abs(idata_multi_ti.data[torch.argmax(idata_multi_ti.header.ti), ...]) +# The maximum absolute value observed is a good approximation for m0 +m0_start = torch.amax(torch.abs(idata_multi_ti.data), 0) # %% # Visualize the starting values fig, axes = plt.subplots(1, 2, figsize=(8, 2), squeeze=False) colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]] im = axes[0, 0].imshow(m0_start[0, 0, ...]) -axes[0, 0].set_title('M0 start values') +axes[0, 0].set_title('$M_0$ start values') fig.colorbar(im, cax=colorbar_ax[0]) -im = axes[0, 1].imshow(t1_start[0, 0, ...], vmin=0, vmax=2500) -axes[0, 1].set_title('T1 start values') -fig.colorbar(im, cax=colorbar_ax[1]) +im = axes[0, 1].imshow(t1_start[0, 0, ...], vmin=0, vmax=2.5) +axes[0, 1].set_title('$T_1$ start values') +fig.colorbar(im, cax=colorbar_ax[1], label='s') # %% [markdown] # ### Carry out fit @@ -134,7 +134,7 @@ # %% # Hyperparameters for optimizer max_iter = 2000 -lr = 1e0 +lr = 1e-1 # Run optimization params_result = adam(functional, [m0_start, t1_start], max_iter=max_iter, lr=lr) @@ -146,7 +146,7 @@ # ### Visualize the final results # To get an impression of how well the fit has worked, we are going to calculate the relative error between # -# $E_{relative} = \sum_{TI}\frac{|(q(M_0, T1, TI) - x)|}{|x|}$ +# $E_{relative} = \sum_{TI}\frac{|(q(M_0, T_1, TI) - x)|}{|x|}$ # # on a voxel-by-voxel basis @@ -158,11 +158,11 @@ fig, axes = plt.subplots(1, 3, figsize=(10, 2), squeeze=False) colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]] im = axes[0, 0].imshow(m0[0, 0, ...]) -axes[0, 0].set_title('M0') +axes[0, 0].set_title('$M_0$') fig.colorbar(im, cax=colorbar_ax[0]) -im = axes[0, 1].imshow(t1[0, 0, ...], vmin=0, vmax=2500) -axes[0, 1].set_title('T1') -fig.colorbar(im, cax=colorbar_ax[1]) +im = axes[0, 1].imshow(t1[0, 0, ...], vmin=0, vmax=2.5) +axes[0, 1].set_title('$T_1$') +fig.colorbar(im, cax=colorbar_ax[1], label='s') im = axes[0, 2].imshow(relative_absolute_error[0, 0, ...], vmin=0, vmax=1.0) axes[0, 2].set_title('Relative error') fig.colorbar(im, cax=colorbar_ax[2]) diff --git a/examples/qmri_sg_challenge_2024_t2_star.ipynb b/examples/qmri_sg_challenge_2024_t2_star.ipynb index 1fda8b92c..ad4122033 100644 --- a/examples/qmri_sg_challenge_2024_t2_star.ipynb +++ b/examples/qmri_sg_challenge_2024_t2_star.ipynb @@ -5,7 +5,7 @@ "id": "5efa18e9", "metadata": {}, "source": [ - "# QMRI Challenge ISMRM 2024 - T2* mapping" + "# QMRI Challenge ISMRM 2024 - $T_2^*$ mapping" ] }, { @@ -39,7 +39,7 @@ "source": [ "### Overview\n", "The dataset consists of gradient echo images obtained at 11 different echo times, each saved in a separate DICOM file.\n", - "In order to obtain a T2* map, we are going to:\n", + "In order to obtain a $T_2^*$ map, we are going to:\n", "- download the data from Zenodo\n", "- read in the DICOM files (one for each echo time) and combine them in an IData object\n", "- define a signal model (mono-exponential decay) and data loss (mean-squared error) function\n", @@ -100,6 +100,8 @@ "source": [ "te_dicom_files = data_folder.glob('**/*.dcm')\n", "idata_multi_te = IData.from_dicom_files(te_dicom_files)\n", + "# scaling the signal down to make the optimization easier\n", + "idata_multi_te.data[...] = idata_multi_te.data / 1500\n", "\n", "# Move the data to the GPU\n", "if flag_use_cuda:\n", @@ -120,7 +122,7 @@ "fig, axes = plt.subplots(1, 3, squeeze=False)\n", "for idx, ax in enumerate(axes.flatten()):\n", " ax.imshow(torch.abs(idata_multi_te.data[idx, 0, 0, :, :]).cpu())\n", - " ax.set_title(f'TE = {idata_multi_te.header.te[idx]:.0f}ms')" + " ax.set_title(f'TE = {idata_multi_te.header.te[idx]:.3f}s')" ] }, { @@ -131,9 +133,9 @@ "### Signal model and loss function\n", "We use the model $q$\n", "\n", - "$q(TE) = M_0 e^{-TE/T2^*}$\n", + "$q(TE) = M_0 e^{-TE/T_2^*}$\n", "\n", - "with the equilibrium magnetization $M_0$, the echo time $TE$, and $T2^*$" + "with the equilibrium magnetization $M_0$, the echo time $TE$, and $T_2^*$" ] }, { @@ -176,7 +178,7 @@ "source": [ "Now we can simply combine the two into a functional which will then solve\n", "\n", - "$ \\min_{M_0, T2^*} ||q(M_0, T2^*, TE) - x||_2^2$" + "$ \\min_{M_0, T_2^*} ||q(M_0, T_2^*, TE) - x||_2^2$" ] }, { @@ -207,11 +209,11 @@ "# The shortest echo time is a good approximation of the equilibrium magnetization\n", "m0_start = torch.abs(idata_multi_te.data[torch.argmin(idata_multi_te.header.te), ...])\n", "# 20 ms as a starting value for T2*\n", - "t2star_start = torch.ones(m0_start.shape, dtype=torch.float32, device=m0_start.device) * 20\n", + "t2star_start = torch.ones(m0_start.shape, dtype=torch.float32, device=m0_start.device) * 20e-3\n", "\n", "# Hyperparameters for optimizer\n", "max_iter = 20000\n", - "lr = 1e0\n", + "lr = 1e-3\n", "\n", "if flag_use_cuda:\n", " functional.cuda()\n", @@ -235,7 +237,7 @@ "### Visualize the final results\n", "To get an impression of how well the fit has worked, we are going to calculate the relative error between\n", "\n", - "$E_{relative} = \\sum_{TE}\\frac{|(q(M_0, T2^*, TE) - x)|}{|x|}$\n", + "$E_{relative} = \\sum_{TE}\\frac{|(q(M_0, T_2^*, TE) - x)|}{|x|}$\n", "\n", "on a voxel-by-voxel basis." ] @@ -257,12 +259,12 @@ "colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]]\n", "\n", "im = axes[0, 0].imshow(m0[0, 0, ...].cpu())\n", - "axes[0, 0].set_title('M0')\n", + "axes[0, 0].set_title('$M_0$')\n", "fig.colorbar(im, cax=colorbar_ax[0])\n", "\n", - "im = axes[0, 1].imshow(t2star[0, 0, ...].cpu(), vmin=0, vmax=500)\n", - "axes[0, 1].set_title('T2*')\n", - "fig.colorbar(im, cax=colorbar_ax[1])\n", + "im = axes[0, 1].imshow(t2star[0, 0, ...].cpu(), vmin=0, vmax=5)\n", + "axes[0, 1].set_title('$T_2^*$')\n", + "fig.colorbar(im, cax=colorbar_ax[1], label='s')\n", "\n", "im = axes[0, 2].imshow(relative_absolute_error[0, 0, ...].cpu(), vmin=0, vmax=0.1)\n", "axes[0, 2].set_title('Relative error')\n", diff --git a/examples/qmri_sg_challenge_2024_t2_star.py b/examples/qmri_sg_challenge_2024_t2_star.py index e7e28372f..a80f40754 100644 --- a/examples/qmri_sg_challenge_2024_t2_star.py +++ b/examples/qmri_sg_challenge_2024_t2_star.py @@ -1,5 +1,5 @@ # %% [markdown] -# # QMRI Challenge ISMRM 2024 - T2* mapping +# # QMRI Challenge ISMRM 2024 - $T_2^*$ mapping # %% # Imports @@ -21,7 +21,7 @@ # %% [markdown] # ### Overview # The dataset consists of gradient echo images obtained at 11 different echo times, each saved in a separate DICOM file. -# In order to obtain a T2* map, we are going to: +# In order to obtain a $T_2^*$ map, we are going to: # - download the data from Zenodo # - read in the DICOM files (one for each echo time) and combine them in an IData object # - define a signal model (mono-exponential decay) and data loss (mean-squared error) function @@ -48,6 +48,8 @@ # %% te_dicom_files = data_folder.glob('**/*.dcm') idata_multi_te = IData.from_dicom_files(te_dicom_files) +# scaling the signal down to make the optimization easier +idata_multi_te.data[...] = idata_multi_te.data / 1500 # Move the data to the GPU if flag_use_cuda: @@ -61,15 +63,15 @@ fig, axes = plt.subplots(1, 3, squeeze=False) for idx, ax in enumerate(axes.flatten()): ax.imshow(torch.abs(idata_multi_te.data[idx, 0, 0, :, :]).cpu()) - ax.set_title(f'TE = {idata_multi_te.header.te[idx]:.0f}ms') + ax.set_title(f'TE = {idata_multi_te.header.te[idx]:.3f}s') # %% [markdown] # ### Signal model and loss function # We use the model $q$ # -# $q(TE) = M_0 e^{-TE/T2^*}$ +# $q(TE) = M_0 e^{-TE/T_2^*}$ # -# with the equilibrium magnetization $M_0$, the echo time $TE$, and $T2^*$ +# with the equilibrium magnetization $M_0$, the echo time $TE$, and $T_2^*$ # %% model = MonoExponentialDecay(decay_time=idata_multi_te.header.te) @@ -83,7 +85,7 @@ # %% [markdown] # Now we can simply combine the two into a functional which will then solve # -# $ \min_{M_0, T2^*} ||q(M_0, T2^*, TE) - x||_2^2$ +# $ \min_{M_0, T_2^*} ||q(M_0, T_2^*, TE) - x||_2^2$ # %% functional = mse @ model @@ -94,11 +96,11 @@ # The shortest echo time is a good approximation of the equilibrium magnetization m0_start = torch.abs(idata_multi_te.data[torch.argmin(idata_multi_te.header.te), ...]) # 20 ms as a starting value for T2* -t2star_start = torch.ones(m0_start.shape, dtype=torch.float32, device=m0_start.device) * 20 +t2star_start = torch.ones(m0_start.shape, dtype=torch.float32, device=m0_start.device) * 20e-3 # Hyperparameters for optimizer max_iter = 20000 -lr = 1e0 +lr = 1e-3 if flag_use_cuda: functional.cuda() @@ -115,7 +117,7 @@ # ### Visualize the final results # To get an impression of how well the fit has worked, we are going to calculate the relative error between # -# $E_{relative} = \sum_{TE}\frac{|(q(M_0, T2^*, TE) - x)|}{|x|}$ +# $E_{relative} = \sum_{TE}\frac{|(q(M_0, T_2^*, TE) - x)|}{|x|}$ # # on a voxel-by-voxel basis. # %% @@ -127,12 +129,12 @@ colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]] im = axes[0, 0].imshow(m0[0, 0, ...].cpu()) -axes[0, 0].set_title('M0') +axes[0, 0].set_title('$M_0$') fig.colorbar(im, cax=colorbar_ax[0]) -im = axes[0, 1].imshow(t2star[0, 0, ...].cpu(), vmin=0, vmax=500) -axes[0, 1].set_title('T2*') -fig.colorbar(im, cax=colorbar_ax[1]) +im = axes[0, 1].imshow(t2star[0, 0, ...].cpu(), vmin=0, vmax=5) +axes[0, 1].set_title('$T_2^*$') +fig.colorbar(im, cax=colorbar_ax[1], label='s') im = axes[0, 2].imshow(relative_absolute_error[0, 0, ...].cpu(), vmin=0, vmax=0.1) axes[0, 2].set_title('Relative error') diff --git a/examples/t1_mapping_with_grad_acq.ipynb b/examples/t1_mapping_with_grad_acq.ipynb index d46eddc73..cfe252e13 100644 --- a/examples/t1_mapping_with_grad_acq.ipynb +++ b/examples/t1_mapping_with_grad_acq.ipynb @@ -5,7 +5,7 @@ "id": "83bfb574", "metadata": {}, "source": [ - "# T1 mapping from a continuous Golden radial acquisition" + "# $T_1$ mapping from a continuous Golden radial acquisition" ] }, { @@ -36,35 +36,88 @@ }, { "cell_type": "markdown", - "id": "7f7c1229", - "metadata": {}, + "id": "29eabc2a", + "metadata": { + "lines_to_next_cell": 2 + }, "source": [ "### Overview\n", "In this acquisition, a single inversion pulse is played out, followed by a continuous data acquisition with a\n", "a constant flip angle $\\alpha$. Data acquisition is carried out with a 2D Golden angle radial trajectory. The acquired\n", "data can be divided into different dynamic time frames, each corresponding to a different inversion time. A signal\n", - "model can then be fitted to this data to obtain a $T_1$ map. More information can be found in:\n", - "\n", - "Kerkering KM, Schulz-Menger J, Schaeffter T, Kolbitsch C (2023) Motion-corrected model-based reconstruction for 2D\n", - "myocardial T1 mapping, MRM 90 https://doi.org/10.1002/mrm.29699\n", + "model can then be fitted to this data to obtain a $T_1$ map.\n", "\n", + "More information can be found in:\n", + "Kerkering KM, Schulz-Menger J, Schaeffter T, Kolbitsch C (2023). Motion-corrected model-based reconstruction for 2D\n", + "myocardial $T_1$ mapping. *Magnetic Resonance in Medicine*, 90(3):1086-1100, [10.1002/mrm.29699](https://doi.org/10.1002/mrm.29699)" + ] + }, + { + "cell_type": "markdown", + "id": "2f2c110e", + "metadata": {}, + "source": [ "The number of time frames and hence the number of radial lines per time frame, can in principle be chosen arbitrarily.\n", "However, a tradeoff between image quality (more radial lines per dynamic) and\n", - "temporal resolution to accurately capture the signal behavior (fewer radial lines) needs to be found.\n", - "\n", + "temporal resolution to accurately capture the signal behavior (fewer radial lines) needs to be found." + ] + }, + { + "cell_type": "markdown", + "id": "1ed1fc05", + "metadata": {}, + "source": [ "During data acquisition, the magnetization $M_z(t)$ can be described by the signal model:\n", - " $$ M_z(t) = M_0^* + (M_0^{init} - M_0^*)e^{(-t / T_1^*)} \\quad (1) $$\n", + "\n", + "$$\n", + " M_z(t) = M_0^* + (M_0^{init} - M_0^*)e^{(-t / T_1^*)} \\quad (1)\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "8b1e3c2f", + "metadata": {}, + "source": [ "where the effective longitudinal relaxation time is given by:\n", - " $$ T_1^* = \\frac{1}{\\frac{1}{T1} - \\frac{1}{T_R} ln(cos(\\alpha))} $$\n", + "\n", + "$$\n", + " T_1^* = \\frac{1}{\\frac{1}{T_1} - \\frac{1}{T_R} \\ln(\\cos(\\alpha))}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "1c6c6616", + "metadata": {}, + "source": [ "and the steady-state magnetization is\n", - " $$ M_0^* = M_0 \\frac{T_1^*}{T_1} .$$\n", "\n", + "$$\n", + " M_0^* = M_0 \\frac{T_1^*}{T_1} .\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "52b8c555", + "metadata": {}, + "source": [ "The initial magnetization $M_0^{init}$ after an inversion pulse is $-M_0$. Nevertheless, commonly after an inversion\n", "pulse, a strong spoiler gradient is played out to remove any residual transversal magnetization due to\n", "imperfections of the inversion pulse. During the spoiler gradient, the magnetization recovers with $T_1$. Commonly,\n", "the duration of this spoiler gradient $\\Delta t$ is between 10 to 20 ms. This leads to the initial magnetization\n", - " $$ M_0^{init} = M_0(1 - 2e^{(-\\Delta t / T_1)}) .$$\n", "\n", + "$$\n", + " M_0^{init} = M_0(1 - 2e^{(-\\Delta t / T_1)}) .\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "7f7c1229", + "metadata": {}, + "source": [ "In this example, we are going to:\n", "- Reconstruct a single high quality image using all acquired radial lines.\n", "- Split the data into multiple dynamics and reconstruct these dynamic images\n", @@ -82,7 +135,7 @@ "source": [ "# Download raw data in ISMRMRD format from zenodo into a temporary directory\n", "data_folder = Path(tempfile.mkdtemp())\n", - "dataset = '10671597'\n", + "dataset = '13207352'\n", "zenodo_get.zenodo_get([dataset, '-r', 5, '-o', data_folder]) # r: retries" ] }, @@ -182,7 +235,7 @@ "id": "87260553", "metadata": {}, "source": [ - "## Estimate T1 map" + "## Estimate $T_1$ map" ] }, { @@ -223,9 +276,9 @@ "source": [ "We also need the repetition time between two RF-pulses. There is a parameter `tr` in the header, but this describes\n", "the time \"between the beginning of a pulse sequence and the beginning of the succeeding (essentially identical) pulse\n", - "sequence\" (see https://dicom.innolitics.com/ciods/mr-image/mr-image/00180080). We have one inversion pulse at the\n", - "beginning, which is never repeated and hence `tr` is the duration of the entire scan. Therefore, we have to use the\n", - "parameter `echo_spacing`, which describes the time between two gradient echoes." + "sequence\" (see [DICOM Standard Browser](https://dicom.innolitics.com/ciods/mr-image/mr-image/00180080)). We have one\n", + "inversion pulse at the beginning, which is never repeated and hence `tr` is the duration of the entire scan.\n", + "Therefore, we have to use the parameter `echo_spacing`, which describes the time between two gradient echoes." ] }, { @@ -340,7 +393,9 @@ "source": [ "Now we can simply combine the loss function, the signal model and the constraints to solve\n", "\n", - "$$ \\min_{M_0, T_1, \\alpha} || |q(M_0, T_1, \\alpha)| - x||_2^2$$" + "$$\n", + " \\min_{M_0, T_1, \\alpha} || |q(M_0, T_1, \\alpha)| - x||_2^2\n", + "$$" ] }, { @@ -406,10 +461,10 @@ "fig, axes = plt.subplots(1, 3, figsize=(10, 2), squeeze=False)\n", "colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]]\n", "im = axes[0, 0].imshow(m0[0, ...].abs(), cmap='gray')\n", - "axes[0, 0].set_title('M0')\n", + "axes[0, 0].set_title('$M_0$')\n", "fig.colorbar(im, cax=colorbar_ax[0])\n", "im = axes[0, 1].imshow(t1[0, ...], vmin=0, vmax=2)\n", - "axes[0, 1].set_title('T1 (s)')\n", + "axes[0, 1].set_title('$T_1$ (s)')\n", "fig.colorbar(im, cax=colorbar_ax[1])\n", "im = axes[0, 2].imshow(flip_angle[0, ...] / torch.pi * 180, vmin=0, vmax=8)\n", "axes[0, 2].set_title('Flip angle (°)')\n", diff --git a/examples/t1_mapping_with_grad_acq.py b/examples/t1_mapping_with_grad_acq.py index 3a355ec11..de8e31c43 100644 --- a/examples/t1_mapping_with_grad_acq.py +++ b/examples/t1_mapping_with_grad_acq.py @@ -1,5 +1,5 @@ # %% [markdown] -# # T1 mapping from a continuous Golden radial acquisition +# # $T_1$ mapping from a continuous Golden radial acquisition # %% # Imports @@ -25,28 +25,50 @@ # In this acquisition, a single inversion pulse is played out, followed by a continuous data acquisition with a # a constant flip angle $\alpha$. Data acquisition is carried out with a 2D Golden angle radial trajectory. The acquired # data can be divided into different dynamic time frames, each corresponding to a different inversion time. A signal -# model can then be fitted to this data to obtain a $T_1$ map. More information can be found in: -# -# Kerkering KM, Schulz-Menger J, Schaeffter T, Kolbitsch C (2023) Motion-corrected model-based reconstruction for 2D -# myocardial T1 mapping, MRM 90 https://doi.org/10.1002/mrm.29699 +# model can then be fitted to this data to obtain a $T_1$ map. # +# More information can be found in: +# Kerkering KM, Schulz-Menger J, Schaeffter T, Kolbitsch C (2023). Motion-corrected model-based reconstruction for 2D +# myocardial $T_1$ mapping. *Magnetic Resonance in Medicine*, 90(3):1086-1100, [10.1002/mrm.29699](https://doi.org/10.1002/mrm.29699) + + +# %% [markdown] # The number of time frames and hence the number of radial lines per time frame, can in principle be chosen arbitrarily. # However, a tradeoff between image quality (more radial lines per dynamic) and # temporal resolution to accurately capture the signal behavior (fewer radial lines) needs to be found. -# + +# %% [markdown] # During data acquisition, the magnetization $M_z(t)$ can be described by the signal model: -# $$ M_z(t) = M_0^* + (M_0^{init} - M_0^*)e^{(-t / T_1^*)} \quad (1) $$ +# +# $$ +# M_z(t) = M_0^* + (M_0^{init} - M_0^*)e^{(-t / T_1^*)} \quad (1) +# $$ + +# %% [markdown] # where the effective longitudinal relaxation time is given by: -# $$ T_1^* = \frac{1}{\frac{1}{T1} - \frac{1}{T_R} ln(cos(\alpha))} $$ +# +# $$ +# T_1^* = \frac{1}{\frac{1}{T_1} - \frac{1}{T_R} \ln(\cos(\alpha))} +# $$ + +# %% [markdown] # and the steady-state magnetization is -# $$ M_0^* = M_0 \frac{T_1^*}{T_1} .$$ # +# $$ +# M_0^* = M_0 \frac{T_1^*}{T_1} . +# $$ + +# %% [markdown] # The initial magnetization $M_0^{init}$ after an inversion pulse is $-M_0$. Nevertheless, commonly after an inversion # pulse, a strong spoiler gradient is played out to remove any residual transversal magnetization due to # imperfections of the inversion pulse. During the spoiler gradient, the magnetization recovers with $T_1$. Commonly, # the duration of this spoiler gradient $\Delta t$ is between 10 to 20 ms. This leads to the initial magnetization -# $$ M_0^{init} = M_0(1 - 2e^{(-\Delta t / T_1)}) .$$ # +# $$ +# M_0^{init} = M_0(1 - 2e^{(-\Delta t / T_1)}) . +# $$ + +# %% [markdown] # In this example, we are going to: # - Reconstruct a single high quality image using all acquired radial lines. # - Split the data into multiple dynamics and reconstruct these dynamic images @@ -55,7 +77,7 @@ # %% # Download raw data in ISMRMRD format from zenodo into a temporary directory data_folder = Path(tempfile.mkdtemp()) -dataset = '10671597' +dataset = '13207352' zenodo_get.zenodo_get([dataset, '-r', 5, '-o', data_folder]) # r: retries @@ -105,7 +127,7 @@ cax.set_title(f'Dynamic {idx}') # %% [markdown] -# ## Estimate T1 map +# ## Estimate $T_1$ map # %% [markdown] # ### Signal model @@ -129,9 +151,9 @@ # %% [markdown] # We also need the repetition time between two RF-pulses. There is a parameter `tr` in the header, but this describes # the time "between the beginning of a pulse sequence and the beginning of the succeeding (essentially identical) pulse -# sequence" (see https://dicom.innolitics.com/ciods/mr-image/mr-image/00180080). We have one inversion pulse at the -# beginning, which is never repeated and hence `tr` is the duration of the entire scan. Therefore, we have to use the -# parameter `echo_spacing`, which describes the time between two gradient echoes. +# sequence" (see [DICOM Standard Browser](https://dicom.innolitics.com/ciods/mr-image/mr-image/00180080)). We have one +# inversion pulse at the beginning, which is never repeated and hence `tr` is the duration of the entire scan. +# Therefore, we have to use the parameter `echo_spacing`, which describes the time between two gradient echoes. # %% if kdata_dynamic.header.echo_spacing is None: @@ -181,7 +203,9 @@ # %% [markdown] # Now we can simply combine the loss function, the signal model and the constraints to solve # -# $$ \min_{M_0, T_1, \alpha} || |q(M_0, T_1, \alpha)| - x||_2^2$$ +# $$ +# \min_{M_0, T_1, \alpha} || |q(M_0, T_1, \alpha)| - x||_2^2 +# $$ # %% functional = mse_loss @ magnitude_model_op @ constraints_op @@ -212,10 +236,10 @@ fig, axes = plt.subplots(1, 3, figsize=(10, 2), squeeze=False) colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]] im = axes[0, 0].imshow(m0[0, ...].abs(), cmap='gray') -axes[0, 0].set_title('M0') +axes[0, 0].set_title('$M_0$') fig.colorbar(im, cax=colorbar_ax[0]) im = axes[0, 1].imshow(t1[0, ...], vmin=0, vmax=2) -axes[0, 1].set_title('T1 (s)') +axes[0, 1].set_title('$T_1$ (s)') fig.colorbar(im, cax=colorbar_ax[1]) im = axes[0, 2].imshow(flip_angle[0, ...] / torch.pi * 180, vmin=0, vmax=8) axes[0, 2].set_title('Flip angle (°)') From 8d24ebbab4a09f7dd976800ec027f35bc4b99eea Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sat, 16 Nov 2024 15:59:10 +0100 Subject: [PATCH 26/35] Fill range (#540) --- .pre-commit-config.yaml | 4 ++-- pyproject.toml | 8 +++++--- src/mrpro/utils/__init__.py | 4 +++- src/mrpro/utils/fill_range.py | 24 ++++++++++++++++++++++++ tests/utils/test_fill_range.py | 23 +++++++++++++++++++++++ 5 files changed, 57 insertions(+), 6 deletions(-) create mode 100644 src/mrpro/utils/fill_range.py create mode 100644 tests/utils/test_fill_range.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 303fd43fa..2229c194f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,10 +26,10 @@ repos: - id: typos - repo: https://github.com/fzimmermann89/check_all - rev: v1.0 + rev: v1.1 hooks: - id: check-init-all - args: [--double-quotes] + args: [--double-quotes, --fix] exclude: ^tests/ - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/pyproject.toml b/pyproject.toml index a65b1cff3..6f5d09048 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -183,12 +183,14 @@ skip-magic-trailing-comma = false [tool.typos.default] locale = "en-us" +check-filename = false [tool.typos.default.extend-words] -Reson = "Reson" # required for Proc. Intl. Soc. Mag. Reson. Med. +Reson = "Reson" # required for Proc. Intl. Soc. Mag. Reson. Med. iy = "iy" -daa = 'daa' # required for wavelet operator -gaus = 'gaus' # required for wavelet operator +daa = "daa" # required for wavelet operator +gaus = "gaus" # required for wavelet operator +arange = "arange" # torch.arange [tool.typos.files] extend-exclude = [ diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index 80ef9d398..944100d54 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -1,6 +1,7 @@ import mrpro.utils.slice_profiles import mrpro.utils.typing import mrpro.utils.unit_conversion +from mrpro.utils.fill_range import fill_range_ from mrpro.utils.smap import smap from mrpro.utils.remove_repeat import remove_repeat from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop @@ -10,6 +11,7 @@ __all__ = [ "broadcast_right", + "fill_range_", "reduce_view", "remove_repeat", "slice_profiles", @@ -20,4 +22,4 @@ "unsqueeze_left", "unsqueeze_right", "zero_pad_or_crop" -] +] \ No newline at end of file diff --git a/src/mrpro/utils/fill_range.py b/src/mrpro/utils/fill_range.py new file mode 100644 index 000000000..c064c63bd --- /dev/null +++ b/src/mrpro/utils/fill_range.py @@ -0,0 +1,24 @@ +"""Fill tensor in-place along a specified dimension with increasing integers.""" + +import torch + + +def fill_range_(tensor: torch.Tensor, dim: int) -> None: + """ + Fill tensor in-place along a specified dimension with increasing integers. + + Parameters + ---------- + tensor + The tensor to be modified in-place. + + dim + The dimension along which to fill with increasing values. + """ + if not -tensor.ndim <= dim < tensor.ndim: + raise IndexError(f'Dimension {dim} is out of range for tensor with {tensor.ndim} dimensions.') + + dim = dim % tensor.ndim + shape = [s if d == dim else 1 for d, s in enumerate(tensor.shape)] + values = torch.arange(tensor.size(dim), device=tensor.device).reshape(shape) + tensor[:] = values.expand_as(tensor) diff --git a/tests/utils/test_fill_range.py b/tests/utils/test_fill_range.py new file mode 100644 index 000000000..d6dd700c4 --- /dev/null +++ b/tests/utils/test_fill_range.py @@ -0,0 +1,23 @@ +"""Tests for fill_range_""" + +import pytest +import torch +from mrpro.utils import fill_range_ + + +@pytest.mark.parametrize('dtype', [torch.float32, torch.int64], ids=['float32', 'int64']) +def test_fill_range(dtype): + """Test functionality of fill_range.""" + tensor = torch.zeros(3, 4, dtype=dtype) + fill_range_(tensor, dim=1) + expected = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]], dtype=tensor.dtype) + torch.testing.assert_close(tensor, expected) + + +def test_fill_range_dim_out_of_range(): + """Test fill_range_ with a dimension out of range.""" + tensor = torch.zeros(3, 4) + with pytest.raises(IndexError, match='Dimension 2 is out of range'): + fill_range_(tensor, dim=2) + with pytest.raises(IndexError, match='Dimension -3 is out of range'): + fill_range_(tensor, dim=-3) From de79063464f9608d01b4c414cc0f2da6d0a78560 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Thu, 21 Nov 2024 12:56:17 +0100 Subject: [PATCH 27/35] Add apply (out of place) (#547) --- src/mrpro/data/MoveDataMixin.py | 18 +++++++++++++++++ src/mrpro/data/SpatialDimension.py | 13 ++++++++++++ tests/data/test_movedatamixin.py | 23 ++++++++++++++++++++- tests/data/test_spatial_dimension.py | 30 +++++++++++++++++++++++++++- 4 files changed, 82 insertions(+), 2 deletions(-) diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index 8d977d0a6..2ac3c1a58 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -239,6 +239,24 @@ def _convert(data: T) -> T: new.apply_(_convert, memo=memo, recurse=False) return new + def apply( + self: Self, + function: Callable[[Any], Any] | None = None, + *, + recurse: bool = True, + ) -> Self: + """Apply a function to all children. Returns a new object. + + Parameters + ---------- + function + The function to apply to all fields. None is interpreted as a no-op. + recurse + If True, the function will be applied to all children that are MoveDataMixin instances. + """ + new = self.clone().apply_(function, recurse=recurse) + return new + def apply_( self: Self, function: Callable[[Any], Any] | None = None, diff --git a/src/mrpro/data/SpatialDimension.py b/src/mrpro/data/SpatialDimension.py index 46b1db89a..12f94e8a6 100644 --- a/src/mrpro/data/SpatialDimension.py +++ b/src/mrpro/data/SpatialDimension.py @@ -18,6 +18,7 @@ VectorTypes = torch.Tensor ScalarTypes = int | float T = TypeVar('T', torch.Tensor, int, float) + # Covariant types, as SpatialDimension is a Container # and we want, for example, SpatialDimension[int] to also be a SpatialDimension[float] T_co = TypeVar('T_co', torch.Tensor, int, float, covariant=True) @@ -108,6 +109,7 @@ def from_array_zyx( return SpatialDimension(z, y, x) + # This function is mainly for type hinting and docstring def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self: """Apply a function to each z, y, x (in-place). @@ -118,6 +120,17 @@ def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self: """ return super(SpatialDimension, self).apply_(function) + # This function is mainly for type hinting and docstring + def apply(self, function: Callable[[T], T] | None = None, **_) -> Self: + """Apply a function to each z, y, x (returning a new object). + + Parameters + ---------- + function + function to apply + """ + return super(SpatialDimension, self).apply(function) + @property def zyx(self) -> tuple[T_co, T_co, T_co]: """Return a z,y,x tuple.""" diff --git a/tests/data/test_movedatamixin.py b/tests/data/test_movedatamixin.py index 3feb091de..c92e12b2a 100644 --- a/tests/data/test_movedatamixin.py +++ b/tests/data/test_movedatamixin.py @@ -207,7 +207,7 @@ def testchild(attribute, expected_dtype): assert new.module.module1.weight is new.module.module1.weight, 'shared module parameters should remain shared' -def test_movedatamixin_apply(): +def test_movedatamixin_apply_(): """Tests apply_ method of MoveDataMixin.""" data = B() # make one of the parameters shared to test memo behavior @@ -223,3 +223,24 @@ def multiply_by_2(obj): torch.testing.assert_close(data.floattensor, original.floattensor * 2) torch.testing.assert_close(data.child.floattensor2, original.child.floattensor2 * 2) assert data.child.floattensor is data.child.floattensor2, 'shared module parameters should remain shared' + + +def test_movedatamixin_apply(): + """Tests apply method of MoveDataMixin.""" + data = B() + # make one of the parameters shared to test memo behavior + data.child.floattensor2 = data.child.floattensor + original = data.clone() + + def multiply_by_2(obj): + if isinstance(obj, torch.Tensor): + return obj * 2 + return obj + + new = data.apply(multiply_by_2) + torch.testing.assert_close(data.floattensor, original.floattensor) + torch.testing.assert_close(data.child.floattensor2, original.child.floattensor2) + torch.testing.assert_close(new.floattensor, original.floattensor * 2) + torch.testing.assert_close(new.child.floattensor2, original.child.floattensor2 * 2) + assert data.child.floattensor is data.child.floattensor2, 'shared module parameters should remain shared' + assert new is not data, 'new object should be different from the original' diff --git a/tests/data/test_spatial_dimension.py b/tests/data/test_spatial_dimension.py index afafece04..61fd127df 100644 --- a/tests/data/test_spatial_dimension.py +++ b/tests/data/test_spatial_dimension.py @@ -93,7 +93,7 @@ def test_spatial_dimension_broadcasting(): def test_spatial_dimension_apply_(): - """Test apply_ (inplace)""" + """Test apply_ (in place)""" def conversion(x: torch.Tensor) -> torch.Tensor: assert isinstance(x, torch.Tensor), 'The argument to the conversion function should be a tensor' @@ -115,6 +115,34 @@ def conversion(x: torch.Tensor) -> torch.Tensor: assert torch.equal(spatial_dimension_inplace.z, z) +def test_spatial_dimension_apply(): + """Test apply (out of place)""" + + def conversion(x: torch.Tensor) -> torch.Tensor: + assert isinstance(x, torch.Tensor), 'The argument to the conversion function should be a tensor' + return x.swapaxes(0, 1).square() + + xyz = RandomGenerator(0).float32_tensor((1, 2, 3)) + spatial_dimension = SpatialDimension.from_array_xyz(xyz.numpy()) + spatial_dimension_outofplace = spatial_dimension.apply(conversion) + + assert spatial_dimension_outofplace is not spatial_dimension + + assert isinstance(spatial_dimension_outofplace.x, torch.Tensor) + assert isinstance(spatial_dimension_outofplace.y, torch.Tensor) + assert isinstance(spatial_dimension_outofplace.z, torch.Tensor) + + x, y, z = conversion(xyz).unbind(-1) + assert torch.equal(spatial_dimension_outofplace.x, x) + assert torch.equal(spatial_dimension_outofplace.y, y) + assert torch.equal(spatial_dimension_outofplace.z, z) + + x, y, z = xyz.unbind(-1) # original should be unmodified + assert torch.equal(spatial_dimension.x, x) + assert torch.equal(spatial_dimension.y, y) + assert torch.equal(spatial_dimension.z, z) + + def test_spatial_dimension_zyx(): """Test the zyx tuple property""" z, y, x = (2, 3, 4) From bf2859bf537833232015fe165cd4fc6fc52c8dc3 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 26 Nov 2024 16:39:37 +0100 Subject: [PATCH 28/35] Fix DCF and FourierOp edge cases (#553) Our FourierOp did not work for image_size < nufft_numpoint Our DCF voronoi assumed that the k dimension requiring vorinoi did not contain singleton dimensions. --- src/mrpro/data/DcfData.py | 2 +- src/mrpro/operators/FourierOp.py | 6 +++--- tests/data/test_dcf_data.py | 16 +++++++++++++++- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/mrpro/data/DcfData.py b/src/mrpro/data/DcfData.py index 62726744d..3db7f6772 100644 --- a/src/mrpro/data/DcfData.py +++ b/src/mrpro/data/DcfData.py @@ -56,7 +56,7 @@ def from_traj_voronoi(cls, traj: KTrajectory) -> Self: if ks_needing_voronoi: # Handle full dimensions needing voronoi - dcfs.append(smap(dcf_2d3d_voronoi, torch.stack(list(ks_needing_voronoi), -4), 4)) + dcfs.append(smap(dcf_2d3d_voronoi, torch.stack(torch.broadcast_tensors(*ks_needing_voronoi), -4), 4)) if dcfs: # Multiply all dcfs together diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index a3e81aba7..2e2087f78 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -108,17 +108,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): # Broadcast shapes not always needed but also does not hurt omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega] self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction - + numpoints = [min(img_size, nufft_numpoints) for img_size in self._nufft_im_size] self._fwd_nufft_op: KbNufftAdjoint | None = KbNufft( im_size=self._nufft_im_size, grid_size=grid_size, - numpoints=nufft_numpoints, + numpoints=numpoints, kbwidth=nufft_kbwidth, ) self._adj_nufft_op: KbNufftAdjoint | None = KbNufftAdjoint( im_size=self._nufft_im_size, grid_size=grid_size, - numpoints=nufft_numpoints, + numpoints=numpoints, kbwidth=nufft_kbwidth, ) else: diff --git a/tests/data/test_dcf_data.py b/tests/data/test_dcf_data.py index 72761739c..314ad2ba2 100644 --- a/tests/data/test_dcf_data.py +++ b/tests/data/test_dcf_data.py @@ -5,6 +5,8 @@ from einops import repeat from mrpro.data import DcfData, KTrajectory +from tests import RandomGenerator + def example_traj_rpe(n_kr, n_ka, n_k0, broadcast=True): """Create RPE trajectory with uniform angular gap.""" @@ -17,7 +19,7 @@ def example_traj_rpe(n_kr, n_ka, n_k0, broadcast=True): return trajectory -def example_traj_spiral_2d(n_kr, n_ki, n_ka, broadcast=True) -> KTrajectory: +def example_traj_spiral_2d(n_kr: int, n_ki: int, n_ka: int, broadcast: bool = True) -> KTrajectory: """Create 2D spiral trajectory with n_kr points along each spiral arm, n_ki turns per spiral arm and n_ka spiral arms.""" ang = repeat(torch.linspace(0, 2 * torch.pi * n_ki, n_kr), 'k0 -> other k2 k1 k0', other=1, k2=1, k1=1) @@ -82,3 +84,15 @@ def test_dcf_rpe_traj_voronoi_cuda(n_kr, n_ka, n_k0): trajectory = example_traj_rpe(n_kr, n_ka, n_k0) dcf = DcfData.from_traj_voronoi(trajectory.cuda()) assert dcf.data.is_cuda + + +def test_dcf_broadcast(): + """Test broadcasting within voronoi dcf calculation.""" + rng = RandomGenerator(0) + # kx and ky force voronoi calculation and need to be broadcasted + kx = rng.float32_tensor((1, 1, 4, 4)) + ky = rng.float32_tensor((1, 4, 1, 4)) + kz = torch.zeros(1, 1, 1, 1) + trajectory = KTrajectory(kz, ky, kx) + dcf = DcfData.from_traj_voronoi(trajectory) + assert dcf.data.shape == trajectory.broadcasted_shape From 41111461575f60b39c6b97fe92af24dbef782ed2 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 26 Nov 2024 16:43:30 +0100 Subject: [PATCH 29/35] Fix gram and adjoint of AdjointLinearOperator (#541) --- src/mrpro/operators/LinearOperator.py | 7 +------ tests/operators/test_operators.py | 6 ++++++ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/mrpro/operators/LinearOperator.py b/src/mrpro/operators/LinearOperator.py index 029089f5d..74d3bafdd 100644 --- a/src/mrpro/operators/LinearOperator.py +++ b/src/mrpro/operators/LinearOperator.py @@ -415,9 +415,4 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: @property def H(self) -> LinearOperator: # noqa: N802 """Adjoint of adjoint operator, i.e. original LinearOperator.""" - return self.operator - - @property - def gram(self) -> LinearOperator: - """Gram operator.""" - return self._operator.gram.H + return self._operator diff --git a/tests/operators/test_operators.py b/tests/operators/test_operators.py index 378060b74..7c1712cdd 100644 --- a/tests/operators/test_operators.py +++ b/tests/operators/test_operators.py @@ -337,6 +337,12 @@ def test_sum_operator_multiple_adjoint(): dotproduct_adjointness_test(linear_op_sum, u, v) +def test_adjoint_of_adjoint(): + """Test that the adjoint of the adjoint is the original operator""" + a = DummyLinearOperator(RandomGenerator(7).complex64_tensor((3, 10))) + assert a.H.H is a + + def test_gram_shortcuts(): """Test that .gram for composition and scalar multiplication results in shortcuts.""" From 82801e6528352e6ef9af6a4db3b04b740e6485da Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 26 Nov 2024 17:02:23 +0100 Subject: [PATCH 30/35] Release v0.241126 (#561) --- src/mrpro/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/VERSION b/src/mrpro/VERSION index c60039027..1e8c1c10e 100644 --- a/src/mrpro/VERSION +++ b/src/mrpro/VERSION @@ -1 +1 @@ -0.241112 +0.241126 From b5a3b5b4ecb0c904fdfb79ddce911bf1c1ba8b96 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Dec 2024 08:30:28 +0100 Subject: [PATCH 31/35] [pre-commit] pre-commit autoupdate (#573) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- src/mrpro/data/AcqInfo.py | 2 +- src/mrpro/data/MoveDataMixin.py | 2 +- src/mrpro/utils/slice_profiles.py | 2 +- src/mrpro/utils/typing.py | 2 +- src/mrpro/utils/unit_conversion.py | 12 ++++++------ 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2229c194f..3920fb5ae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,14 +14,14 @@ repos: - id: mixed-line-ending - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.2 + rev: v0.8.1 hooks: - id: ruff # linter args: [--fix] - id: ruff-format # formatter - repo: https://github.com/crate-ci/typos - rev: v1.27.0 + rev: typos-dict-v0.11.37 hooks: - id: typos diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index f5d677f97..b11e07d86 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -188,7 +188,7 @@ def tensor(data: np.ndarray) -> torch.Tensor: data = data.astype(np.int32) case np.uint32 | np.uint64: data = data.astype(np.int64) - # Remove any uncessary dimensions + # Remove any unnecessary dimensions return torch.tensor(np.squeeze(data)) def tensor_2d(data: np.ndarray) -> torch.Tensor: diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index 2ac3c1a58..99bcb3df5 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -121,7 +121,7 @@ def parse1( ) -> parsedType: return device, dtype, non_blocking, copy, memory_format - if args and isinstance(args[0], torch.Tensor) or 'tensor' in kwargs: + if (args and isinstance(args[0], torch.Tensor)) or 'tensor' in kwargs: # overload 3 ("tensor" specifies the dtype and device) device, dtype, non_blocking, copy, memory_format = parse3(*args, **kwargs) elif args and isinstance(args[0], torch.dtype): diff --git a/src/mrpro/utils/slice_profiles.py b/src/mrpro/utils/slice_profiles.py index 0e5fafc1f..8eaf95311 100644 --- a/src/mrpro/utils/slice_profiles.py +++ b/src/mrpro/utils/slice_profiles.py @@ -8,7 +8,7 @@ import torch from torch import Tensor -__all__ = ['SliceProfileBase', 'SliceGaussian', 'SliceSmoothedRectangular', 'SliceInterpolate'] +__all__ = ['SliceGaussian', 'SliceInterpolate', 'SliceProfileBase', 'SliceSmoothedRectangular'] class SliceProfileBase(abc.ABC, torch.nn.Module): diff --git a/src/mrpro/utils/typing.py b/src/mrpro/utils/typing.py index f90e18dab..96ad34d07 100644 --- a/src/mrpro/utils/typing.py +++ b/src/mrpro/utils/typing.py @@ -28,4 +28,4 @@ NestedSequence: TypeAlias = Any NumpyIndexerType: TypeAlias = Any -__all__ = ['TorchIndexerType', 'NumpyIndexerType', 'NestedSequence'] +__all__ = ['NestedSequence', 'NumpyIndexerType', 'TorchIndexerType'] diff --git a/src/mrpro/utils/unit_conversion.py b/src/mrpro/utils/unit_conversion.py index 0115bed47..5a1b5aaae 100644 --- a/src/mrpro/utils/unit_conversion.py +++ b/src/mrpro/utils/unit_conversion.py @@ -6,15 +6,15 @@ import torch __all__ = [ - 'ms_to_s', - 's_to_ms', - 'mm_to_m', - 'm_to_mm', + 'GYROMAGNETIC_RATIO_PROTON', 'deg_to_rad', - 'rad_to_deg', 'lamor_frequency_to_magnetic_field', + 'm_to_mm', 'magnetic_field_to_lamor_frequency', - 'GYROMAGNETIC_RATIO_PROTON', + 'mm_to_m', + 'ms_to_s', + 'rad_to_deg', + 's_to_ms', ] GYROMAGNETIC_RATIO_PROTON = 42.58 * 1e6 From b80141febc17f1d561c2ec0842c577bffdab0459 Mon Sep 17 00:00:00 2001 From: Patrick Schuenke <37338697+schuenke@users.noreply.github.com> Date: Mon, 9 Dec 2024 09:31:09 +0100 Subject: [PATCH 32/35] Fix trajectory scaling in KTrajectoryPulseq (#551) Co-authored-by: Felix F Zimmermann --- .../data/traj_calculators/KTrajectoryPulseq.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py index 7c843572d..598aa2184 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryPulseq.py @@ -54,13 +54,18 @@ def __call__(self, kheader: KHeader) -> KTrajectoryRawShape: raise ValueError('We currently only support constant number of samples') n_k0 = int(n_samples.item()) - def reshape_pulseq_traj(k_traj: torch.Tensor, encoding_size: int): - k_traj *= encoding_size / (2 * torch.max(torch.abs(k_traj))) + def rescale_and_reshape_traj(k_traj: torch.Tensor, encoding_size: int): + if encoding_size > 1 and torch.max(torch.abs(k_traj)) > 0: + k_traj = k_traj * encoding_size / (2 * torch.max(torch.abs(k_traj))) + else: + # We force k_traj to be 0 if encoding_size = 1. This is typically the case for kz in 2D sequences. + # However, it happens that seq.calculate_kspace() returns values != 0 (numerical noise) in such cases. + k_traj = torch.zeros_like(k_traj) return rearrange(k_traj, '(other k0) -> other k0', k0=n_k0) # rearrange k-space trajectory to match MRpro convention - kx = reshape_pulseq_traj(k_traj_adc[0], kheader.encoding_matrix.x) - ky = reshape_pulseq_traj(k_traj_adc[1], kheader.encoding_matrix.y) - kz = reshape_pulseq_traj(k_traj_adc[2], kheader.encoding_matrix.z) + kx = rescale_and_reshape_traj(k_traj_adc[0], kheader.encoding_matrix.x) + ky = rescale_and_reshape_traj(k_traj_adc[1], kheader.encoding_matrix.y) + kz = rescale_and_reshape_traj(k_traj_adc[2], kheader.encoding_matrix.z) return KTrajectoryRawShape(kz, ky, kx, self.repeat_detection_tolerance) From 31756e37e2dbd0875b1665bed8d6325246ae2627 Mon Sep 17 00:00:00 2001 From: Lunin Leonid Date: Tue, 10 Dec 2024 09:37:45 +0100 Subject: [PATCH 33/35] Add python 3.10 to badge and fix project keywords (#550) Co-authored-by: Patrick Schuenke <37338697+schuenke@users.noreply.github.com> --- README.md | 2 +- pyproject.toml | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 422c4ef4d..c02b80deb 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # MRpro -![Python](https://img.shields.io/badge/python-3.11%20%7C%203.12-blue) +![Python](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) ![Coverage Bagde](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/ckolbPTB/48e334a10caf60e6708d7c712e56d241/raw/coverage.json) diff --git a/pyproject.toml b/pyproject.toml index 6f5d09048..70db1cafd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,20 @@ description = "MR image reconstruction and processing package specifically devel readme = "README.md" requires-python = ">=3.10,<3.14" dynamic = ["version"] -keywords = ["MRI, reconstruction, processing, PyTorch"] +keywords = ["MRI", + "qMRI", + "medical imaging", + "physics-informed learning", + "model-based reconstruction", + "quantitative", + "signal models", + "machine learning", + "deep learning", + "reconstruction", + "processing", + "Pulseq", + "PyTorch", +] authors = [ { name = "MRpro Team", email = "info@emerpro.de" }, { name = "Christoph Kolbitsch", email = "christoph.kolbitsch@ptb.de" }, From e15faa1673e3eb0789c2ac96e0a8892776145f5b Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 10 Dec 2024 09:41:48 +0100 Subject: [PATCH 34/35] Add memory efficient reshape of broadcasted tensors (#557) --- src/mrpro/utils/__init__.py | 3 +- src/mrpro/utils/reshape.py | 101 ++++++++++++++++++++++++++++++++++++ tests/utils/test_reshape.py | 34 +++++++++++- 3 files changed, 136 insertions(+), 2 deletions(-) diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index 944100d54..b6e3c77ef 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -6,7 +6,7 @@ from mrpro.utils.remove_repeat import remove_repeat from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop from mrpro.utils.split_idx import split_idx -from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view +from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view, reshape_broadcasted import mrpro.utils.unit_conversion __all__ = [ @@ -14,6 +14,7 @@ "fill_range_", "reduce_view", "remove_repeat", + "reshape_broadcasted", "slice_profiles", "smap", "split_idx", diff --git a/src/mrpro/utils/reshape.py b/src/mrpro/utils/reshape.py index 31d495afd..0b208381b 100644 --- a/src/mrpro/utils/reshape.py +++ b/src/mrpro/utils/reshape.py @@ -1,6 +1,8 @@ """Tensor reshaping utilities.""" from collections.abc import Sequence +from functools import lru_cache +from math import prod import torch @@ -99,3 +101,102 @@ def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torc for d, (oldsize, stride) in enumerate(zip(x.size(), stride, strict=True)) ] return torch.as_strided(x, newsize, stride) + + +@lru_cache +def _reshape_idx(old_shape: tuple[int, ...], new_shape: tuple[int, ...], old_stride: tuple[int, ...]) -> list[slice]: + """Get reshape reduce index (Cached helper function for reshape_broadcasted). + + This function tries to group axes from new_shape and old_shape into the smallest groups that have + the same number of elements, starting from the right. + If all axes of old shape of a group are stride=0 dimensions, we can reduce them. + + Example: + old_shape = (30, 2, 2, 3) + new_shape = (6, 5, 4, 3) + Will results in the groups (starting from the right): + - old: 3 new: 3 + - old: 2, 2 new: 4 + - old: 30 new: 6, 5 + Only the "old" groups are important. + If all axes that are grouped together in an "old" group are stride 0 (=broadcasted) + we can collapse them to singleton dimensions. + This function returns the indexer that either collapses dimensions to singleton or keeps all + elements, i.e. the slices in the returned list are all either slice(1) or slice(None). + """ + idx = [] + pointer_old, pointer_new = len(old_shape) - 1, len(new_shape) - 1 # start from the right + while pointer_old >= 0: + product_new, product_old = 1, 1 # the number of elements in the current "new" and "old" group + group: list[int] = [] + while product_old != product_new or not group: + if product_old <= product_new: + # increase "old" group + product_old *= old_shape[pointer_old] + group.append(pointer_old) + pointer_old -= 1 + else: + # increase "new" group + # we don't need to track the new group, the number of elemeents covered. + product_new *= new_shape[pointer_new] + pointer_new -= 1 + # we found a group. now we need to decide what to do. + if all(old_stride[d] == 0 for d in group): + # all dimensions are broadcasted + # -> reduce to singleton + idx.extend([slice(1)] * len(group)) + else: + # preserve dimension + idx.extend([slice(None)] * len(group)) + idx = idx[::-1] # we worked right to left, but our index should be left to right + return idx + + +def reshape_broadcasted(tensor: torch.Tensor, *shape: int) -> torch.Tensor: + """Reshape a tensor while preserving broadcasted (stride 0) dimensions where possible. + + Parameters + ---------- + tensor + The input tensor to reshape. + shape + The target shape for the tensor. One of the values can be `-1` and its size will be inferred. + + Returns + ------- + A tensor reshaped to the target shape, preserving broadcasted dimensions where feasible. + + """ + try: + # if we can view the tensor directly, it will preserve broadcasting + return tensor.view(shape) + except RuntimeError: + # we cannot do a view, we need to do more work: + + # -1 means infer size, i.e. the remaining elements of the input not already covered by the other axes. + negative_ones = shape.count(-1) + size = tensor.shape.numel() + if not negative_ones: + if prod(shape) != size: + # use same exception as pytorch + raise RuntimeError(f"shape '{list(shape)}' is invalid for input of size {size}") from None + elif negative_ones > 1: + raise RuntimeError('only one dimension can be inferred') from None + elif negative_ones == 1: + # we need to figure out the size of the "-1" dimension + known_size = -prod(shape) # negative, is it includes the -1 + if size % known_size: + # non integer result. no possible size of the -1 axis exists. + raise RuntimeError(f"shape '{list(shape)}' is invalid for input of size {size}") from None + shape = tuple(size // known_size if s == -1 else s for s in shape) + + # most of the broadcasted dimensions can be preserved: only dimensions that are joined with non + # broadcasted dimensions can not be preserved and must be made contiguous. + # all dimensions that can be preserved as broadcasted are first collapsed to singleton, + # such that contiguous does not create copies along these axes. + idx = _reshape_idx(tensor.shape, shape, tensor.stride()) + # make contiguous only in dimensions in which broadcasting cannot be preserved + semicontiguous = tensor[idx].contiguous() + # finally, we can expand the broadcasted dimensions to the requested shape + semicontiguous = semicontiguous.expand(tensor.shape) + return semicontiguous.view(shape) diff --git a/tests/utils/test_reshape.py b/tests/utils/test_reshape.py index dd57b8feb..3a2cc0cf5 100644 --- a/tests/utils/test_reshape.py +++ b/tests/utils/test_reshape.py @@ -1,7 +1,8 @@ """Tests for reshaping utilities.""" +import pytest import torch -from mrpro.utils import broadcast_right, reduce_view, unsqueeze_left, unsqueeze_right +from mrpro.utils import broadcast_right, reduce_view, reshape_broadcasted, unsqueeze_left, unsqueeze_right from tests import RandomGenerator @@ -51,3 +52,34 @@ def test_reduce_view(): reduced_one_pos = reduce_view(tensor, 0) assert reduced_one_pos.shape == (1, 2, 3, 4, 5, 6) assert torch.equal(reduced_one_pos.expand_as(tensor), tensor) + + +@pytest.mark.parametrize( + ('shape', 'expand_shape', 'permute', 'final_shape', 'expected_stride'), + [ + ((1, 2, 3, 1, 1), (1, 2, 3, 4, 5), (0, 2, 1, 3, 4), (1, 6, 2, 2, 5), (6, 1, 0, 0, 0)), + ((1, 2, 1), (100, 2, 2), (0, 1, 2), (100, 4), (0, 1)), + ((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 0, 1), (1, 2, 6, 10, 1), (0, 0, 0, 0, 0)), + ((1, 2, 3), (1, -1, 3), (0, 1, 2), (6,), (1,)), + ], +) +def test_reshape_broadcasted(shape, expand_shape, permute, final_shape, expected_stride): + """Test reshape_broadcasted""" + rng = RandomGenerator(0) + tensor = rng.float32_tensor(shape).expand(*expand_shape).permute(*permute) + reshaped = reshape_broadcasted(tensor, *final_shape) + expected_values = tensor.reshape(*final_shape) + assert reshaped.shape == expected_values.shape + assert reshaped.stride() == expected_stride + assert torch.equal(reshaped, expected_values) + + +def test_reshape_broadcasted_fail(): + """Test reshape_broadcasted with invalid input""" + a = torch.ones(2) + with pytest.raises(RuntimeError, match='invalid'): + reshape_broadcasted(a, 3) + with pytest.raises(RuntimeError, match='invalid'): + reshape_broadcasted(a, -1, -3) + with pytest.raises(RuntimeError, match='only one dimension'): + reshape_broadcasted(a, -1, -1) From 28038b11d852fb01df76e33f220bac5cb2de7d4c Mon Sep 17 00:00:00 2001 From: Patrick Schuenke <37338697+schuenke@users.noreply.github.com> Date: Tue, 10 Dec 2024 17:31:22 +0100 Subject: [PATCH 35/35] Release v0.241210 (#576) --- src/mrpro/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/VERSION b/src/mrpro/VERSION index 1e8c1c10e..bcc11aaab 100644 --- a/src/mrpro/VERSION +++ b/src/mrpro/VERSION @@ -1 +1 @@ -0.241126 +0.241210