From dcb5e3407a32e2cbd9085ffbe4ce9430f8bacad0 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Sat, 9 Nov 2024 15:26:16 +0100 Subject: [PATCH 01/14] Add PCA-based compression operator (#181) Co-authored-by: Christoph Kolbitsch --- src/mrpro/operators/PCACompressionOp.py | 85 ++++++++++++++++++++++ src/mrpro/operators/__init__.py | 2 + tests/operators/test_pca_compression_op.py | 50 +++++++++++++ 3 files changed, 137 insertions(+) create mode 100644 src/mrpro/operators/PCACompressionOp.py create mode 100644 tests/operators/test_pca_compression_op.py diff --git a/src/mrpro/operators/PCACompressionOp.py b/src/mrpro/operators/PCACompressionOp.py new file mode 100644 index 000000000..ace625ea8 --- /dev/null +++ b/src/mrpro/operators/PCACompressionOp.py @@ -0,0 +1,85 @@ +"""PCA Compression Operator.""" + +import einops +import torch +from einops import repeat + +from mrpro.operators.LinearOperator import LinearOperator + + +class PCACompressionOp(LinearOperator): + """PCA based compression operator.""" + + def __init__( + self, + data: torch.Tensor, + n_components: int, + ): + """Construct a PCA based compression operator. + + The operator carries out an SVD followed by a threshold of the n_components largest values along the last + dimension of a data with shape (*other, joint_dim, compression_dim). A single SVD is carried out for everything + along joint_dim. Other are batch dimensions. + + Consider combining this operator with :class:`mrpro.operators.RearrangeOp` to make sure the data is + in the correct shape before applying. + + Parameters + ---------- + data + Data of shape (*other, joint_dim, compression_dim) to be used to find the principal components. + n_components + Number of principal components to keep along the compression_dim. + """ + super().__init__() + # different compression matrices along the *other dimensions + data = data - data.mean(-1, keepdim=True) + correlation = einops.einsum(data, data.conj(), '... joint comp1, ... joint comp2 -> ... comp1 comp2') + _, _, v = torch.svd(correlation) + # add joint_dim along which the the compression is the same + v = repeat(v, '... comp1 comp2 -> ... joint_dim comp1 comp2', joint_dim=1) + self.register_buffer('_compression_matrix', v[..., :n_components, :].clone()) + + def forward(self, data: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the compression to the data. + + Parameters + ---------- + data + data to be compressed of shape (*other, joint_dim, compression_dim) + + Returns + ------- + compressed data of shape (*other, joint_dim, n_components) + """ + try: + result = (self._compression_matrix @ data.unsqueeze(-1)).squeeze(-1) + except RuntimeError as e: + raise RuntimeError( + 'Shape mismatch in adjoint Compression: ' + f'Matrix {tuple(self._compression_matrix.shape)} ' + f'cannot be multiplied with Data {tuple(data.shape)}.' + ) from e + return (result,) + + def adjoint(self, data: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint compression to the data. + + Parameters + ---------- + data + compressed data of shape (*other, joint_dim, n_components) + + Returns + ------- + expanded data of shape (*other, joint_dim, compression_dim) + """ + try: + result = (self._compression_matrix.mH @ data.unsqueeze(-1)).squeeze(-1) + except RuntimeError as e: + raise RuntimeError( + 'Shape mismatch in adjoint Compression: ' + f'Matrix^H {tuple(self._compression_matrix.mH.shape)} ' + f'cannot be multiplied with Data {tuple(data.shape)}.' + ) from e + return (result,) diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index b8c16ebfe..c22f386cd 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -14,6 +14,7 @@ from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix from mrpro.operators.MagnitudeOp import MagnitudeOp from mrpro.operators.MultiIdentityOp import MultiIdentityOp +from mrpro.operators.PCACompressionOp import PCACompressionOp from mrpro.operators.PhaseOp import PhaseOp from mrpro.operators.ProximableFunctionalSeparableSum import ProximableFunctionalSeparableSum from mrpro.operators.SensitivityOp import SensitivityOp @@ -41,6 +42,7 @@ "MagnitudeOp", "MultiIdentityOp", "Operator", + "PCACompressionOp", "PhaseOp", "ProximableFunctional", "ProximableFunctionalSeparableSum", diff --git a/tests/operators/test_pca_compression_op.py b/tests/operators/test_pca_compression_op.py new file mode 100644 index 000000000..a600bcecb --- /dev/null +++ b/tests/operators/test_pca_compression_op.py @@ -0,0 +1,50 @@ +"""Tests for PCA Compression Operator.""" + +import pytest +from mrpro.operators import PCACompressionOp + +from tests import RandomGenerator +from tests.helper import dotproduct_adjointness_test + + +@pytest.mark.parametrize( + ('init_data_shape', 'input_shape', 'n_components'), + [ + ((40, 10), (100, 10), 6), + ((40, 10), (3, 4, 5, 100, 10), 3), + ((3, 4, 40, 10), (3, 4, 100, 10), 6), + ((3, 4, 40, 10), (7, 3, 4, 100, 10), 3), + ], +) +def test_pca_compression_op_adjoint(init_data_shape, input_shape, n_components): + """Test adjointness of PCA Compression Op.""" + + # Create test data + generator = RandomGenerator(seed=0) + data_to_calculate_compression_matrix_from = generator.complex64_tensor(init_data_shape) + u = generator.complex64_tensor(input_shape) + output_shape = (*input_shape[:-1], n_components) + v = generator.complex64_tensor(output_shape) + + # Create operator and apply + pca_comp_op = PCACompressionOp(data=data_to_calculate_compression_matrix_from, n_components=n_components) + dotproduct_adjointness_test(pca_comp_op, u, v) + + +def test_pca_compression_op_wrong_shapes(): + """Test if Operator raises error if shape mismatch.""" + init_data_shape = (10, 6) + input_shape = (100, 3) + + # Create test data + generator = RandomGenerator(seed=0) + data_to_calculate_compression_matrix_from = generator.complex64_tensor(init_data_shape) + input_data = generator.complex64_tensor(input_shape) + + pca_comp_op = PCACompressionOp(data=data_to_calculate_compression_matrix_from, n_components=2) + + with pytest.raises(RuntimeError, match='Matrix'): + pca_comp_op(input_data) + + with pytest.raises(RuntimeError, match='Matrix.H'): + pca_comp_op.adjoint(input_data) From c268ad25a429f93d3bb33df6cb2bf0ffc06c4759 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Sat, 9 Nov 2024 16:11:57 +0100 Subject: [PATCH 02/14] Fix CartesianSamplingOp (#483) --- src/mrpro/operators/CartesianSamplingOp.py | 10 ++++++---- tests/operators/test_cartesian_sampling_op.py | 17 +++++++++++++++-- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index 64068a5d1..7a51924b1 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -47,26 +47,28 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) -> kx_idx = ktraj_tensor[-1, ...].round().to(dtype=torch.int64) + sorted_grid_shape.x // 2 else: sorted_grid_shape.x = ktraj_tensor.shape[-1] - kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k1 k2 k0', other=1, k2=1, k1=1) + kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k2 k1 k0', other=1, k2=1, k1=1) if traj_type_kzyx[-2] == TrajType.ONGRID: # ky ky_idx = ktraj_tensor[-2, ...].round().to(dtype=torch.int64) + sorted_grid_shape.y // 2 else: sorted_grid_shape.y = ktraj_tensor.shape[-2] - ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k1 k2 k0', other=1, k2=1, k0=1) + ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k2 k1 k0', other=1, k2=1, k0=1) if traj_type_kzyx[-3] == TrajType.ONGRID: # kz kz_idx = ktraj_tensor[-3, ...].round().to(dtype=torch.int64) + sorted_grid_shape.z // 2 else: sorted_grid_shape.z = ktraj_tensor.shape[-3] - kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k1 k2 k0', other=1, k1=1, k0=1) + kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k2 k1 k0', other=1, k1=1, k0=1) # 1D indices into a flattened tensor. kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)') self.register_buffer('_fft_idx', kidx) # we can skip the indexing if the data is already sorted - self._needs_indexing = not torch.all(torch.diff(kidx) == 1) + self._needs_indexing = ( + not torch.all(torch.diff(kidx) == 1) or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx + ) self._trajectory_shape = traj.broadcasted_shape self._sorted_grid_shape = sorted_grid_shape diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 6a1120e79..7caa13e7c 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -2,6 +2,7 @@ import pytest import torch +from einops import rearrange from mrpro.data import KTrajectory, SpatialDimension from mrpro.operators import CartesianSamplingOp @@ -59,6 +60,9 @@ def test_cart_sampling_op_data_match(): 'regular_undersampling', 'random_undersampling', 'different_random_undersampling', + 'cartesian_and_non_cartesian', + 'kx_ky_along_k0', + 'kx_ky_along_k0_undersampling', ], ) def test_cart_sampling_op_fwd_adj(sampling): @@ -70,8 +74,8 @@ def test_cart_sampling_op_fwd_adj(sampling): nky = (2, 1, 40, 1) nkz = (2, 20, 1, 1) sx = 'uf' - sy = 'uf' - sz = 'uf' + sy = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' + sz = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz).as_tensor() # Subsample data and trajectory @@ -94,6 +98,15 @@ def test_cart_sampling_op_fwd_adj(sampling): for traj_one_other in trajectory_tensor.unbind(1) ] trajectory = KTrajectory.from_tensor(torch.stack(traj_list, dim=1)) + case 'cartesian_and_non_cartesian': + trajectory = KTrajectory.from_tensor(trajectory_tensor) + case 'kx_ky_along_k0': + trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)') + trajectory = KTrajectory.from_tensor(trajectory_tensor) + case 'kx_ky_along_k0_undersampling': + trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)') + random_idx = torch.randperm(trajectory_tensor.shape[-1]) + trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: trajectory_tensor.shape[-1] // 2]]) case _: raise NotImplementedError(f'Test {sampling} not implemented.') From 54674a95b59165cbf5cb21e1a71d1bb616921554 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 11 Nov 2024 01:55:28 +0100 Subject: [PATCH 03/14] Add reduce_view (#476) Undoes expand, i.e. replaces stride 0 dimensions by size=1 dimensions --- src/mrpro/utils/__init__.py | 3 ++- src/mrpro/utils/reshape.py | 32 ++++++++++++++++++++++++++++++++ tests/utils/test_reshape.py | 32 +++++++++++++++++++++++++++++--- 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index c09071f4b..b16fae37a 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -5,10 +5,11 @@ from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop from mrpro.utils.modify_acq_info import modify_acq_info from mrpro.utils.split_idx import split_idx -from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right +from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view __all__ = [ "broadcast_right", "modify_acq_info", + "reduce_view", "remove_repeat", "slice_profiles", "smap", diff --git a/src/mrpro/utils/reshape.py b/src/mrpro/utils/reshape.py index 39d12e51f..31d495afd 100644 --- a/src/mrpro/utils/reshape.py +++ b/src/mrpro/utils/reshape.py @@ -1,5 +1,7 @@ """Tensor reshaping utilities.""" +from collections.abc import Sequence + import torch @@ -67,3 +69,33 @@ def broadcast_right(*x: torch.Tensor) -> tuple[torch.Tensor, ...]: max_dim = max(el.ndim for el in x) unsqueezed = torch.broadcast_tensors(*(unsqueeze_right(el, max_dim - el.ndim) for el in x)) return unsqueezed + + +def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torch.Tensor: + """Reduce expanded dimensions in a view to singletons. + + Reduce either all or specific dimensions to a singleton if it + points to the same memory address. + This undoes expand. + + Parameters + ---------- + x + input tensor + dim + only reduce expanded dimensions in the specified dimensions. + If None, reduce all expanded dimensions. + """ + if dim is None: + dim_: Sequence[int] = range(x.ndim) + elif isinstance(dim, Sequence): + dim_ = [d % x.ndim for d in dim] + else: + dim_ = [dim % x.ndim] + + stride = x.stride() + newsize = [ + 1 if stride == 0 and d in dim_ else oldsize + for d, (oldsize, stride) in enumerate(zip(x.size(), stride, strict=True)) + ] + return torch.as_strided(x, newsize, stride) diff --git a/tests/utils/test_reshape.py b/tests/utils/test_reshape.py index 60a0dc5e3..dd57b8feb 100644 --- a/tests/utils/test_reshape.py +++ b/tests/utils/test_reshape.py @@ -1,7 +1,9 @@ """Tests for reshaping utilities.""" import torch -from mrpro.utils import broadcast_right, unsqueeze_left, unsqueeze_right +from mrpro.utils import broadcast_right, reduce_view, unsqueeze_left, unsqueeze_right + +from tests import RandomGenerator def test_broadcast_right(): @@ -12,7 +14,7 @@ def test_broadcast_right(): def test_unsqueeze_left(): - """Test unsqueeze left""" + """Test unsqueeze_left""" tensor = torch.ones(1, 2, 3) unsqueezed = unsqueeze_left(tensor, 2) assert unsqueezed.shape == (1, 1, 1, 2, 3) @@ -20,8 +22,32 @@ def test_unsqueeze_left(): def test_unsqueeze_right(): - """Test unsqueeze right""" + """Test unsqueeze_right""" tensor = torch.ones(1, 2, 3) unsqueezed = unsqueeze_right(tensor, 2) assert unsqueezed.shape == (1, 2, 3, 1, 1) assert torch.equal(tensor.ravel(), unsqueezed.ravel()) + + +def test_reduce_view(): + """Test reduce_view""" + + tensor = RandomGenerator(0).float32_tensor((1, 2, 3, 1, 1, 1)) + tensor = tensor.expand(1, 2, 3, 4, 1, 1).contiguous() # this cannot be removed + tensor = tensor.expand(7, 2, 3, 4, 5, 6) + + reduced_all = reduce_view(tensor) + assert reduced_all.shape == (1, 2, 3, 4, 1, 1) + assert torch.equal(reduced_all.expand_as(tensor), tensor) + + reduced_two = reduce_view(tensor, (0, -1)) + assert reduced_two.shape == (1, 2, 3, 4, 5, 1) + assert torch.equal(reduced_two.expand_as(tensor), tensor) + + reduced_one_neg = reduce_view(tensor, -1) + assert reduced_one_neg.shape == (7, 2, 3, 4, 5, 1) + assert torch.equal(reduced_one_neg.expand_as(tensor), tensor) + + reduced_one_pos = reduce_view(tensor, 0) + assert reduced_one_pos.shape == (1, 2, 3, 4, 5, 6) + assert torch.equal(reduced_one_pos.expand_as(tensor), tensor) From f0f91c3ff91296a07e6bb7a184684f6674a91795 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 11 Nov 2024 11:35:57 +0100 Subject: [PATCH 04/14] Add apply_ to dataclasses (#505) Applies a function to all children of a dataclass --- src/mrpro/data/AcqInfo.py | 14 ++---- src/mrpro/data/MoveDataMixin.py | 68 ++++++++++++++++++++++------ src/mrpro/data/SpatialDimension.py | 55 +++++----------------- tests/data/test_movedatamixin.py | 19 ++++++++ tests/data/test_spatial_dimension.py | 23 ---------- 5 files changed, 89 insertions(+), 90 deletions(-) diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index 83f752a57..a66224de1 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -1,6 +1,6 @@ """Acquisition information dataclass.""" -from collections.abc import Callable, Sequence +from collections.abc import Sequence from dataclasses import dataclass import ismrmrd @@ -206,17 +206,13 @@ def tensor_2d(data: np.ndarray) -> torch.Tensor: data_tensor = data_tensor[None, None] return data_tensor - def spatialdimension_2d( - data: np.ndarray, conversion: Callable[[torch.Tensor], torch.Tensor] | None = None - ) -> SpatialDimension[torch.Tensor]: + def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: # Ensure spatial dimension is (k1*k2*other, 1, 3) if data.ndim != 2: raise ValueError('Spatial dimension is expected to be of shape (N,3)') data = data[:, None, :] # all spatial dimensions are float32 - return ( - SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32))).apply_(conversion) - ) + return SpatialDimension[torch.Tensor].from_array_xyz(torch.tensor(data.astype(np.float32))) acq_idx = AcqIdx( k1=tensor(idx['kspace_encode_step_1']), @@ -251,10 +247,10 @@ def spatialdimension_2d( flags=tensor_2d(headers['flags']), measurement_uid=tensor_2d(headers['measurement_uid']), number_of_samples=tensor_2d(headers['number_of_samples']), - patient_table_position=spatialdimension_2d(headers['patient_table_position'], mm_to_m), + patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m), phase_dir=spatialdimension_2d(headers['phase_dir']), physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']), - position=spatialdimension_2d(headers['position'], mm_to_m), + position=spatialdimension_2d(headers['position']).apply_(mm_to_m), read_dir=spatialdimension_2d(headers['read_dir']), sample_time_us=tensor_2d(headers['sample_time_us']), scan_counter=tensor_2d(headers['scan_counter']), diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index f3f147260..8d977d0a6 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -1,13 +1,13 @@ """MoveDataMixin.""" import dataclasses -from collections.abc import Iterator +from collections.abc import Callable, Iterator from copy import copy as shallowcopy from copy import deepcopy -from typing import ClassVar, TypeAlias +from typing import ClassVar, TypeAlias, cast import torch -from typing_extensions import Any, Protocol, Self, overload, runtime_checkable +from typing_extensions import Any, Protocol, Self, TypeVar, overload, runtime_checkable class InconsistentDeviceError(ValueError): # noqa: D101 @@ -22,6 +22,9 @@ class DataclassInstance(Protocol): __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] +T = TypeVar('T') + + class MoveDataMixin: """Move dataclass fields to cpu/gpu and convert dtypes.""" @@ -151,7 +154,6 @@ def _to( copy: bool = False, memo: dict | None = None, ) -> Self: - new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self """Move data to device and convert dtype if necessary. This method is called by .to(), .cuda(), .cpu(), .double(), and so on. @@ -179,6 +181,8 @@ def _to( memo A dictionary to keep track of already converted objects to avoid multiple conversions. """ + new = shallowcopy(self) if copy or not isinstance(self, torch.nn.Module) else self + if memo is None: memo = {} @@ -219,26 +223,62 @@ def _mixin_to(obj: MoveDataMixin) -> MoveDataMixin: memo=memo, ) - converted: Any - for name, data in new._items(): - if id(data) in memo: - object.__setattr__(new, name, memo[id(data)]) - continue + def _convert(data: T) -> T: + converted: Any # https://github.com/python/mypy/issues/10817 if isinstance(data, torch.Tensor): converted = _tensor_to(data) elif isinstance(data, MoveDataMixin): converted = _mixin_to(data) elif isinstance(data, torch.nn.Module): converted = _module_to(data) - elif copy: - converted = deepcopy(data) else: converted = data - memo[id(data)] = converted - # this works even if new is frozen - object.__setattr__(new, name, converted) + return cast(T, converted) + + # manual recursion allows us to do the copy only once + new.apply_(_convert, memo=memo, recurse=False) return new + def apply_( + self: Self, + function: Callable[[Any], Any] | None = None, + *, + memo: dict[int, Any] | None = None, + recurse: bool = True, + ) -> Self: + """Apply a function to all children in-place. + + Parameters + ---------- + function + The function to apply to all fields. None is interpreted as a no-op. + memo + A dictionary to keep track of objects that the function has already been applied to, + to avoid multiple applications. This is useful if the object has a circular reference. + recurse + If True, the function will be applied to all children that are MoveDataMixin instances. + """ + applied: Any + + if memo is None: + memo = {} + + if function is None: + return self + + for name, data in self._items(): + if id(data) in memo: + # this works even if self is frozen + object.__setattr__(self, name, memo[id(data)]) + continue + if recurse and isinstance(data, MoveDataMixin): + applied = data.apply_(function, memo=memo) + else: + applied = function(data) + memo[id(data)] = applied + object.__setattr__(self, name, applied) + return self + def cuda( self, device: torch.device | str | int | None = None, diff --git a/src/mrpro/data/SpatialDimension.py b/src/mrpro/data/SpatialDimension.py index b5f3dfd27..46b1db89a 100644 --- a/src/mrpro/data/SpatialDimension.py +++ b/src/mrpro/data/SpatialDimension.py @@ -3,14 +3,13 @@ from __future__ import annotations from collections.abc import Callable -from copy import deepcopy from dataclasses import dataclass from typing import Generic, get_args import numpy as np import torch from numpy.typing import ArrayLike -from typing_extensions import Any, Protocol, TypeVar, overload +from typing_extensions import Protocol, Self, TypeVar, overload import mrpro.utils.typing as type_utils from mrpro.data.MoveDataMixin import MoveDataMixin @@ -109,6 +108,16 @@ def from_array_zyx( return SpatialDimension(z, y, x) + def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self: + """Apply a function to each z, y, x (in-place). + + Parameters + ---------- + function + function to apply + """ + return super(SpatialDimension, self).apply_(function) + @property def zyx(self) -> tuple[T_co, T_co, T_co]: """Return a z,y,x tuple.""" @@ -134,48 +143,6 @@ def __setitem__(self: SpatialDimension[T_co_vector], idx: type_utils.TorchIndexe self.y[idx] = other.y self.x[idx] = other.x - def apply_(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]: - """Apply function to each of x,y,z in-place. - - Parameters - ---------- - func - function to apply to each of x,y,z - None is interpreted as the identity function. - """ - if func is not None: - self.z = func(self.z) - self.y = func(self.y) - self.x = func(self.x) - return self - - def apply(self: SpatialDimension[T_co], func: Callable[[T_co], T_co] | None = None) -> SpatialDimension[T_co]: - """Apply function to each of x,y,z. - - Parameters - ---------- - func - function to apply to each of x,y,z - None is interpreted as the identity function. - """ - - def func_(x: Any) -> T_co: # noqa: ANN401 - if isinstance(x, torch.Tensor): - # use clone for autograd - x = x.clone() - else: - x = deepcopy(x) - if func is None: - return x - else: - return func(x) - - return self.__class__(func_(self.z), func_(self.y), func_(self.x)) - - def clone(self: SpatialDimension[T_co]) -> SpatialDimension[T_co]: - """Return a deep copy of the SpatialDimension.""" - return self.apply() - @overload def __mul__(self: SpatialDimension[T_co], other: T_co | SpatialDimension[T_co]) -> SpatialDimension[T_co]: ... diff --git a/tests/data/test_movedatamixin.py b/tests/data/test_movedatamixin.py index 06f55a4dc..3feb091de 100644 --- a/tests/data/test_movedatamixin.py +++ b/tests/data/test_movedatamixin.py @@ -23,6 +23,7 @@ class A(MoveDataMixin): """Test class A.""" floattensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0)) + floattensor2: torch.Tensor = field(default_factory=lambda: torch.tensor(-1.0)) complextensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1.0, dtype=torch.complex64)) inttensor: torch.Tensor = field(default_factory=lambda: torch.tensor(1, dtype=torch.int32)) booltensor: torch.Tensor = field(default_factory=lambda: torch.tensor(True)) @@ -204,3 +205,21 @@ def testchild(attribute, expected_dtype): assert original is not new, 'original and new should not be the same object' assert new.module.module1.weight is new.module.module1.weight, 'shared module parameters should remain shared' + + +def test_movedatamixin_apply(): + """Tests apply_ method of MoveDataMixin.""" + data = B() + # make one of the parameters shared to test memo behavior + data.child.floattensor2 = data.child.floattensor + original = data.clone() + + def multiply_by_2(obj): + if isinstance(obj, torch.Tensor): + return obj * 2 + return obj + + data.apply_(multiply_by_2) + torch.testing.assert_close(data.floattensor, original.floattensor * 2) + torch.testing.assert_close(data.child.floattensor2, original.child.floattensor2 * 2) + assert data.child.floattensor is data.child.floattensor2, 'shared module parameters should remain shared' diff --git a/tests/data/test_spatial_dimension.py b/tests/data/test_spatial_dimension.py index cd46854b9..afafece04 100644 --- a/tests/data/test_spatial_dimension.py +++ b/tests/data/test_spatial_dimension.py @@ -115,29 +115,6 @@ def conversion(x: torch.Tensor) -> torch.Tensor: assert torch.equal(spatial_dimension_inplace.z, z) -def test_spatial_dimension_apply(): - """Test apply (out of place)""" - - def conversion(x: torch.Tensor) -> torch.Tensor: - assert isinstance(x, torch.Tensor), 'The argument to the conversion function should be a tensor' - return x.swapaxes(0, 1).square() - - xyz = RandomGenerator(0).float32_tensor((1, 2, 3)) - spatial_dimension = SpatialDimension.from_array_xyz(xyz.numpy()) - spatial_dimension_outofplace = spatial_dimension.apply().apply(conversion) - - assert spatial_dimension_outofplace is not spatial_dimension - - assert isinstance(spatial_dimension_outofplace.x, torch.Tensor) - assert isinstance(spatial_dimension_outofplace.y, torch.Tensor) - assert isinstance(spatial_dimension_outofplace.z, torch.Tensor) - - x, y, z = conversion(xyz).unbind(-1) - assert torch.equal(spatial_dimension_outofplace.x, x) - assert torch.equal(spatial_dimension_outofplace.y, y) - assert torch.equal(spatial_dimension_outofplace.z, z) - - def test_spatial_dimension_zyx(): """Test the zyx tuple property""" z, y, x = (2, 3, 4) From a96b9c61f2c667f439124f779e399f374a75c646 Mon Sep 17 00:00:00 2001 From: Patrick Schuenke <37338697+schuenke@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:01:46 +0100 Subject: [PATCH 05/14] Fix installation from TestPyPi in workflow (#499) --- .github/workflows/deployment.yml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/workflows/deployment.yml b/.github/workflows/deployment.yml index dd900496e..a55b7860d 100644 --- a/.github/workflows/deployment.yml +++ b/.github/workflows/deployment.yml @@ -94,7 +94,15 @@ jobs: run: | VERSION=${{ needs.build-testpypi-package.outputs.version }} SUFFIX=${{ needs.build-testpypi-package.outputs.suffix }} - python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ + for i in {1..3}; do + if python -m pip install mrpro==$VERSION$SUFFIX --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/; then + echo "Package installed successfully." + break + else + echo "Attempt $i failed. Retrying in 10 seconds..." + sleep 10 + fi + done build-pypi-package: name: Build Package for PyPI From 191ab06b35ecae890febb34d0d2841b8fe2cd4c5 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 11 Nov 2024 17:45:56 +0100 Subject: [PATCH 06/14] Remove check-docstring-first pre-commit hook (#508) --- .pre-commit-config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 95d14317a..4790095f9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,6 @@ repos: rev: v5.0.0 hooks: - id: check-added-large-files - - id: check-docstring-first - id: check-merge-conflict - id: check-yaml - id: check-toml From 6c54e31ac22e80cfbafad88102b4ca1ab30d8241 Mon Sep 17 00:00:00 2001 From: Patrick Schuenke <37338697+schuenke@users.noreply.github.com> Date: Tue, 12 Nov 2024 00:27:11 +0100 Subject: [PATCH 07/14] Revert NamedTemporaryFile ContextManager in example notebooks (#500) --- .pre-commit-config.yaml | 16 +++++++------- examples/direct_reconstruction.ipynb | 8 +++---- examples/direct_reconstruction.py | 8 +++---- examples/iterative_sense_reconstruction.ipynb | 8 +++---- examples/iterative_sense_reconstruction.py | 8 +++---- examples/pulseq_2d_radial_golden_angle.ipynb | 22 +++++++++---------- examples/pulseq_2d_radial_golden_angle.py | 17 +++++++------- ...rized_iterative_sense_reconstruction.ipynb | 8 +++---- ...ularized_iterative_sense_reconstruction.py | 8 +++---- examples/ruff.toml | 1 + 10 files changed, 51 insertions(+), 53 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4790095f9..303fd43fa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,11 +53,11 @@ repos: - "--extra-index-url=https://pypi.python.org/simple" ci: - autofix_commit_msg: | - [pre-commit] auto fixes from pre-commit hooks - autofix_prs: false - autoupdate_branch: '' - autoupdate_commit_msg: '[pre-commit] pre-commit autoupdate' - autoupdate_schedule: monthly - skip: [mypy] - submodules: false + autofix_commit_msg: | + [pre-commit] auto fixes from pre-commit hooks + autofix_prs: false + autoupdate_branch: "" + autoupdate_commit_msg: "[pre-commit] pre-commit autoupdate" + autoupdate_schedule: monthly + skip: [mypy] + submodules: false diff --git a/examples/direct_reconstruction.ipynb b/examples/direct_reconstruction.ipynb index 1e4e74c9c..3b6dc930e 100644 --- a/examples/direct_reconstruction.ipynb +++ b/examples/direct_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", - " response = requests.get(zenodo_url + fname, timeout=30)\n", - " data_file.write(response.content)\n", - " data_file.flush()" + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" ] }, { diff --git a/examples/direct_reconstruction.py b/examples/direct_reconstruction.py index 5d55812c9..7672aa7e7 100644 --- a/examples/direct_reconstruction.py +++ b/examples/direct_reconstruction.py @@ -11,10 +11,10 @@ import requests -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: - response = requests.get(zenodo_url + fname, timeout=30) - data_file.write(response.content) - data_file.flush() +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/iterative_sense_reconstruction.ipynb b/examples/iterative_sense_reconstruction.ipynb index f612d7522..87249b2fb 100644 --- a/examples/iterative_sense_reconstruction.ipynb +++ b/examples/iterative_sense_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", - " response = requests.get(zenodo_url + fname, timeout=30)\n", - " data_file.write(response.content)\n", - " data_file.flush()" + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" ] }, { diff --git a/examples/iterative_sense_reconstruction.py b/examples/iterative_sense_reconstruction.py index ba5e6a01a..6d0bc49a5 100644 --- a/examples/iterative_sense_reconstruction.py +++ b/examples/iterative_sense_reconstruction.py @@ -11,10 +11,10 @@ import requests -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: - response = requests.get(zenodo_url + fname, timeout=30) - data_file.write(response.content) - data_file.flush() +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/pulseq_2d_radial_golden_angle.ipynb b/examples/pulseq_2d_radial_golden_angle.ipynb index bcb4482a1..52e0310bb 100644 --- a/examples/pulseq_2d_radial_golden_angle.ipynb +++ b/examples/pulseq_2d_radial_golden_angle.ipynb @@ -33,14 +33,13 @@ "cell_type": "code", "execution_count": null, "id": "d16f41f1", - "metadata": { - "lines_to_next_cell": 2 - }, + "metadata": {}, "outputs": [], "source": [ "# define zenodo records URL and create a temporary directory and h5-file\n", "zenodo_url = 'https://zenodo.org/records/10854057/files/'\n", - "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'" + "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'\n", + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')" ] }, { @@ -51,10 +50,9 @@ "outputs": [], "source": [ "# Download raw data using requests\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", - " response = requests.get(zenodo_url + fname, timeout=30)\n", - " data_file.write(response.content)\n", - " data_file.flush()" + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" ] }, { @@ -127,10 +125,10 @@ "# download the sequence file from zenodo\n", "zenodo_url = 'https://zenodo.org/records/10868061/files/'\n", "seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq'\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file:\n", - " response = requests.get(zenodo_url + seq_fname, timeout=30)\n", - " seq_file.write(response.content)\n", - " seq_file.flush()" + "seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq')\n", + "response = requests.get(zenodo_url + seq_fname, timeout=30)\n", + "seq_file.write(response.content)\n", + "seq_file.flush()" ] }, { diff --git a/examples/pulseq_2d_radial_golden_angle.py b/examples/pulseq_2d_radial_golden_angle.py index f4db5217a..3f857c382 100644 --- a/examples/pulseq_2d_radial_golden_angle.py +++ b/examples/pulseq_2d_radial_golden_angle.py @@ -19,14 +19,13 @@ # define zenodo records URL and create a temporary directory and h5-file zenodo_url = 'https://zenodo.org/records/10854057/files/' fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5' - +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') # %% # Download raw data using requests -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: - response = requests.get(zenodo_url + fname, timeout=30) - data_file.write(response.content) - data_file.flush() +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() # %% [markdown] # ### Image reconstruction using KTrajectoryIsmrmrd @@ -63,10 +62,10 @@ # download the sequence file from zenodo zenodo_url = 'https://zenodo.org/records/10868061/files/' seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq' -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file: - response = requests.get(zenodo_url + seq_fname, timeout=30) - seq_file.write(response.content) - seq_file.flush() +seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') +response = requests.get(zenodo_url + seq_fname, timeout=30) +seq_file.write(response.content) +seq_file.flush() # %% # Read raw data and calculate trajectory using KTrajectoryPulseq diff --git a/examples/regularized_iterative_sense_reconstruction.ipynb b/examples/regularized_iterative_sense_reconstruction.ipynb index 0a6743161..6b1c2704b 100644 --- a/examples/regularized_iterative_sense_reconstruction.ipynb +++ b/examples/regularized_iterative_sense_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", - " response = requests.get(zenodo_url + fname, timeout=30)\n", - " data_file.write(response.content)\n", - " data_file.flush()" + "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", + "response = requests.get(zenodo_url + fname, timeout=30)\n", + "data_file.write(response.content)\n", + "data_file.flush()" ] }, { diff --git a/examples/regularized_iterative_sense_reconstruction.py b/examples/regularized_iterative_sense_reconstruction.py index 2ab7ba033..e41dc4ac5 100644 --- a/examples/regularized_iterative_sense_reconstruction.py +++ b/examples/regularized_iterative_sense_reconstruction.py @@ -11,10 +11,10 @@ import requests -with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: - response = requests.get(zenodo_url + fname, timeout=30) - data_file.write(response.content) - data_file.flush() +data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') +response = requests.get(zenodo_url + fname, timeout=30) +data_file.write(response.content) +data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/ruff.toml b/examples/ruff.toml index 11a1e6167..1bb114755 100644 --- a/examples/ruff.toml +++ b/examples/ruff.toml @@ -5,4 +5,5 @@ lint.extend-ignore = [ "T20", #print "E402", #module-import-not-at-top-of-file "S101", #assert + "SIM115", #context manager for opening files ] From a89df61455675201b82e86c19b4b8b9743a5068b Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 12 Nov 2024 08:58:11 +0100 Subject: [PATCH 08/14] Adapt KHeader parameters (#506) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/mrpro/data/AcqInfo.py | 41 ++++----- src/mrpro/data/KHeader.py | 48 ++++------ src/mrpro/data/TrajectoryDescription.py | 29 ------ src/mrpro/data/__init__.py | 4 +- src/mrpro/data/_kdata/KData.py | 15 ++-- src/mrpro/data/_kdata/KDataRearrangeMixin.py | 10 +-- src/mrpro/data/_kdata/KDataSelectMixin.py | 9 +- src/mrpro/data/_kdata/KDataSplitMixin.py | 28 +++--- src/mrpro/utils/__init__.py | 7 +- src/mrpro/utils/modify_acq_info.py | 35 -------- src/mrpro/utils/unit_conversion.py | 94 ++++++++++++++++++++ tests/conftest.py | 6 +- tests/data/test_kdata.py | 34 +++++-- tests/utils/test_modify_acq_info.py | 18 ---- tests/utils/test_unit_conversion.py | 82 +++++++++++++++++ 15 files changed, 278 insertions(+), 182 deletions(-) delete mode 100644 src/mrpro/data/TrajectoryDescription.py delete mode 100644 src/mrpro/utils/modify_acq_info.py create mode 100644 src/mrpro/utils/unit_conversion.py delete mode 100644 tests/utils/test_modify_acq_info.py create mode 100644 tests/utils/test_unit_conversion.py diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index a66224de1..f5d677f97 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -6,23 +6,24 @@ import ismrmrd import numpy as np import torch -from typing_extensions import Self, TypeVar +from einops import rearrange +from typing_extensions import Self from mrpro.data.MoveDataMixin import MoveDataMixin +from mrpro.data.Rotation import Rotation from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.utils.unit_conversion import mm_to_m -# Conversion functions for units -T = TypeVar('T', float, torch.Tensor) +def rearrange_acq_info_fields(field: object, pattern: str, **axes_lengths: dict[str, int]) -> object: + """Change the shape of the fields in AcqInfo.""" + if isinstance(field, Rotation): + return Rotation.from_matrix(rearrange(field.as_matrix(), pattern, **axes_lengths)) -def ms_to_s(ms: T) -> T: - """Convert ms to s.""" - return ms / 1000 + if isinstance(field, torch.Tensor): + return rearrange(field, pattern, **axes_lengths) - -def mm_to_m(m: T) -> T: - """Convert mm to m.""" - return m / 1000 + return field @dataclass(slots=True) @@ -121,30 +122,24 @@ class AcqInfo(MoveDataMixin): number_of_samples: torch.Tensor """Number of sample points per readout (readouts may have different number of sample points).""" + orientation: Rotation + """Rotation describing the orientation of the readout, phase and slice encoding direction.""" + patient_table_position: SpatialDimension[torch.Tensor] """Offset position of the patient table, in LPS coordinates [m].""" - phase_dir: SpatialDimension[torch.Tensor] - """Directional cosine of phase encoding (2D).""" - physiology_time_stamp: torch.Tensor """Time stamps relative to physiological triggering, e.g. ECG. Not in s but in vendor-specific time units""" position: SpatialDimension[torch.Tensor] """Center of the excited volume, in LPS coordinates relative to isocenter [m].""" - read_dir: SpatialDimension[torch.Tensor] - """Directional cosine of readout/frequency encoding.""" - sample_time_us: torch.Tensor """Readout bandwidth, as time between samples [us].""" scan_counter: torch.Tensor """Zero-indexed incrementing counter for readouts.""" - slice_dir: SpatialDimension[torch.Tensor] - """Directional cosine of slice normal, i.e. cross-product of read_dir and phase_dir.""" - trajectory_dimensions: torch.Tensor # =3. We only support 3D Trajectories: kz always exists. """Dimensionality of the k-space trajectory vector.""" @@ -247,14 +242,16 @@ def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: flags=tensor_2d(headers['flags']), measurement_uid=tensor_2d(headers['measurement_uid']), number_of_samples=tensor_2d(headers['number_of_samples']), + orientation=Rotation.from_directions( + spatialdimension_2d(headers['slice_dir']), + spatialdimension_2d(headers['phase_dir']), + spatialdimension_2d(headers['read_dir']), + ), patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m), - phase_dir=spatialdimension_2d(headers['phase_dir']), physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']), position=spatialdimension_2d(headers['position']).apply_(mm_to_m), - read_dir=spatialdimension_2d(headers['read_dir']), sample_time_us=tensor_2d(headers['sample_time_us']), scan_counter=tensor_2d(headers['scan_counter']), - slice_dir=spatialdimension_2d(headers['slice_dir']), trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above user_float=tensor_2d(headers['user_float']), user_int=tensor_2d(headers['user_int']), diff --git a/src/mrpro/data/KHeader.py b/src/mrpro/data/KHeader.py index 488c1d9cd..dea12e28b 100644 --- a/src/mrpro/data/KHeader.py +++ b/src/mrpro/data/KHeader.py @@ -13,12 +13,12 @@ from typing_extensions import Self from mrpro.data import enums -from mrpro.data.AcqInfo import AcqInfo, mm_to_m, ms_to_s +from mrpro.data.AcqInfo import AcqInfo from mrpro.data.EncodingLimits import EncodingLimits from mrpro.data.MoveDataMixin import MoveDataMixin from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.data.TrajectoryDescription import TrajectoryDescription from mrpro.utils.summarize_tensorvalues import summarize_tensorvalues +from mrpro.utils.unit_conversion import mm_to_m, ms_to_s if TYPE_CHECKING: # avoid circular imports by importing only when type checking @@ -40,9 +40,6 @@ class KHeader(MoveDataMixin): trajectory: KTrajectoryCalculator """Function to calculate the k-space trajectory.""" - b0: float - """Magnetic field strength [T].""" - encoding_limits: EncodingLimits """K-space encoding limits.""" @@ -61,12 +58,9 @@ class KHeader(MoveDataMixin): acq_info: AcqInfo """Information of the acquisitions (i.e. readout lines).""" - h1_freq: float + lamor_frequency_proton: float """Lamor frequency of hydrogen nuclei [Hz].""" - n_coils: int | None = None - """Number of receiver coils.""" - datetime: datetime.datetime | None = None """Date and time of acquisition.""" @@ -88,7 +82,7 @@ class KHeader(MoveDataMixin): echo_train_length: int = 1 """Number of echoes in a multi-echo acquisition.""" - seq_type: str = UNKNOWN + sequence_type: str = UNKNOWN """Type of sequence.""" model: str = UNKNOWN @@ -100,16 +94,13 @@ class KHeader(MoveDataMixin): protocol_name: str = UNKNOWN """Name of the acquisition protocol.""" - misc: dict = dataclasses.field(default_factory=dict) # do not use {} here! - """Dictionary with miscellaneous parameters.""" - calibration_mode: enums.CalibrationMode = enums.CalibrationMode.OTHER """Mode of how calibration data is acquired. """ interleave_dim: enums.InterleavingDimension = enums.InterleavingDimension.OTHER """Interleaving dimension.""" - traj_type: enums.TrajectoryType = enums.TrajectoryType.OTHER + trajectory_type: enums.TrajectoryType = enums.TrajectoryType.OTHER """Type of trajectory.""" measurement_id: str = UNKNOWN @@ -118,8 +109,9 @@ class KHeader(MoveDataMixin): patient_name: str = UNKNOWN """Name of the patient.""" - trajectory_description: TrajectoryDescription = dataclasses.field(default_factory=TrajectoryDescription) - """Description of the trajectory.""" + _misc: dict = dataclasses.field(default_factory=dict) # do not use {} here! + """Dictionary with miscellaneous parameters. These parameters are for information purposes only. Reconstruction + algorithms should not rely on them.""" @property def fa_degree(self) -> torch.Tensor | None: @@ -160,17 +152,14 @@ def from_ismrmrd( enc: ismrmrdschema.encodingType = header.encoding[encoding_number] # These are guaranteed to exist - parameters = {'h1_freq': header.experimentalConditions.H1resonanceFrequency_Hz, 'acq_info': acq_info} + parameters = { + 'lamor_frequency_proton': header.experimentalConditions.H1resonanceFrequency_Hz, + 'acq_info': acq_info, + } if defaults is not None: parameters.update(defaults) - if ( - header.acquisitionSystemInformation is not None - and header.acquisitionSystemInformation.receiverChannels is not None - ): - parameters['n_coils'] = header.acquisitionSystemInformation.receiverChannels - if header.sequenceParameters is not None: if header.sequenceParameters.TR: parameters['tr'] = ms_to_s(torch.as_tensor(header.sequenceParameters.TR)) @@ -184,7 +173,7 @@ def from_ismrmrd( parameters['echo_spacing'] = ms_to_s(torch.as_tensor(header.sequenceParameters.echo_spacing)) if header.sequenceParameters.sequence_type is not None: - parameters['seq_type'] = header.sequenceParameters.sequence_type + parameters['sequence_type'] = header.sequenceParameters.sequence_type if enc.reconSpace is not None: parameters['recon_fov'] = SpatialDimension[float].from_xyz(enc.reconSpace.fieldOfView_mm).apply_(mm_to_m) @@ -212,7 +201,7 @@ def from_ismrmrd( ) if enc.trajectory is not None: - parameters['traj_type'] = enums.TrajectoryType(enc.trajectory.value) + parameters['trajectory_type'] = enums.TrajectoryType(enc.trajectory.value) # Either use the series or study time if available if header.measurementInformation is not None and header.measurementInformation.seriesTime is not None: @@ -245,15 +234,8 @@ def from_ismrmrd( if header.acquisitionSystemInformation.systemModel is not None: parameters['model'] = header.acquisitionSystemInformation.systemModel - if header.acquisitionSystemInformation.systemFieldStrength_T is not None: - parameters['b0'] = header.acquisitionSystemInformation.systemFieldStrength_T - - # estimate b0 from h1_freq if not given - if 'b0' not in parameters: - parameters['b0'] = parameters['h1_freq'] / 4258e4 - # Dump everything into misc - parameters['misc'] = dataclasses.asdict(header) + parameters['_misc'] = dataclasses.asdict(header) if overwrite is not None: parameters.update(overwrite) diff --git a/src/mrpro/data/TrajectoryDescription.py b/src/mrpro/data/TrajectoryDescription.py deleted file mode 100644 index 801811005..000000000 --- a/src/mrpro/data/TrajectoryDescription.py +++ /dev/null @@ -1,29 +0,0 @@ -"""TrajectoryDescription dataclass.""" - -import dataclasses -from dataclasses import dataclass - -from ismrmrd.xsd.ismrmrdschema.ismrmrd import trajectoryDescriptionType -from typing_extensions import Self - - -@dataclass(slots=True) -class TrajectoryDescription: - """TrajectoryDescription dataclass.""" - - identifier: str = '' - user_parameter_long: dict[str, int] = dataclasses.field(default_factory=dict) - user_parameter_double: dict[str, float] = dataclasses.field(default_factory=dict) - user_parameter_string: dict[str, str] = dataclasses.field(default_factory=dict) - comment: str = '' - - @classmethod - def from_ismrmrd(cls, trajectory_description: trajectoryDescriptionType) -> Self: - """Create TrajectoryDescription from ismrmrd traj description.""" - return cls( - user_parameter_long={p.name: int(p.value) for p in trajectory_description.userParameterLong}, - user_parameter_double={p.name: float(p.value) for p in trajectory_description.userParameterDouble}, - user_parameter_string={p.name: str(p.value) for p in trajectory_description.userParameterString}, - comment=trajectory_description.comment or '', - identifier=trajectory_description.identifier or '', - ) diff --git a/src/mrpro/data/__init__.py b/src/mrpro/data/__init__.py index b5034d668..d5667a5bc 100644 --- a/src/mrpro/data/__init__.py +++ b/src/mrpro/data/__init__.py @@ -16,7 +16,6 @@ from mrpro.data.QHeader import QHeader from mrpro.data.Rotation import Rotation from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.data.TrajectoryDescription import TrajectoryDescription __all__ = [ "AcqIdx", "AcqInfo", @@ -37,8 +36,7 @@ "QHeader", "Rotation", "SpatialDimension", - "TrajectoryDescription", "acq_filters", "enums", "traj_calculators" -] \ No newline at end of file +] diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/_kdata/KData.py index 409e8aac9..57af617bc 100644 --- a/src/mrpro/data/_kdata/KData.py +++ b/src/mrpro/data/_kdata/KData.py @@ -18,15 +18,15 @@ from mrpro.data._kdata.KDataSelectMixin import KDataSelectMixin from mrpro.data._kdata.KDataSplitMixin import KDataSplitMixin from mrpro.data.acq_filters import is_image_acquisition -from mrpro.data.AcqInfo import AcqInfo +from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits from mrpro.data.KHeader import KHeader from mrpro.data.KTrajectory import KTrajectory from mrpro.data.KTrajectoryRawShape import KTrajectoryRawShape from mrpro.data.MoveDataMixin import MoveDataMixin +from mrpro.data.Rotation import Rotation from mrpro.data.traj_calculators.KTrajectoryCalculator import KTrajectoryCalculator from mrpro.data.traj_calculators.KTrajectoryIsmrmrd import KTrajectoryIsmrmrd -from mrpro.utils import modify_acq_info KDIM_SORT_LABELS = ( 'k1', @@ -200,10 +200,13 @@ def from_file( sort_idx = np.lexsort(acq_indices) # torch does not have lexsort as of pytorch 2.2 (March 2024) # Finally, reshape and sort the tensors in acqinfo and acqinfo.idx, and kdata. - def sort_and_reshape_tensor_fields(input_tensor: torch.Tensor): - return rearrange(input_tensor[sort_idx], '(other k2 k1) ... -> other k2 k1 ...', k1=n_k1, k2=n_k2) - - kheader.acq_info = modify_acq_info(sort_and_reshape_tensor_fields, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field[sort_idx], '(other k2 k1) ... -> other k2 k1 ...', k1=n_k1, k2=n_k2 + ) + if isinstance(field, torch.Tensor | Rotation) + else field + ) kdata = rearrange(kdata[sort_idx], '(other k2 k1) coils k0 -> other coils k2 k1 k0', k1=n_k1, k2=n_k2) # Calculate trajectory and check if it matches the kdata shape diff --git a/src/mrpro/data/_kdata/KDataRearrangeMixin.py b/src/mrpro/data/_kdata/KDataRearrangeMixin.py index 05bab7681..23a58dea6 100644 --- a/src/mrpro/data/_kdata/KDataRearrangeMixin.py +++ b/src/mrpro/data/_kdata/KDataRearrangeMixin.py @@ -6,8 +6,7 @@ from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.data.AcqInfo import AcqInfo -from mrpro.utils import modify_acq_info +from mrpro.data.AcqInfo import rearrange_acq_info_fields class KDataRearrangeMixin(_KDataProtocol): @@ -35,9 +34,8 @@ def rearrange_k2_k1_into_k1(self: Self) -> Self: kheader = copy.deepcopy(self.header) # Update shape of acquisition info index - def reshape_acq_info(info: AcqInfo): - return rearrange(info, 'other k2 k1 ... -> other 1 (k2 k1) ...') - - kheader.acq_info = modify_acq_info(reshape_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(field, 'other k2 k1 ... -> other 1 (k2 k1) ...') + ) return type(self)(kheader, kdat, type(self.traj).from_tensor(ktraj)) diff --git a/src/mrpro/data/_kdata/KDataSelectMixin.py b/src/mrpro/data/_kdata/KDataSelectMixin.py index fb0e02aa1..8f8a452cf 100644 --- a/src/mrpro/data/_kdata/KDataSelectMixin.py +++ b/src/mrpro/data/_kdata/KDataSelectMixin.py @@ -7,7 +7,7 @@ from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol -from mrpro.utils import modify_acq_info +from mrpro.data.Rotation import Rotation class KDataSelectMixin(_KDataProtocol): @@ -51,10 +51,9 @@ def select_other_subset( other_idx = torch.cat([torch.where(idx == label_idx[:, 0, 0])[0] for idx in subset_idx], dim=0) # Adapt header - def select_acq_info(info: torch.Tensor): - return info[other_idx, ...] - - kheader.acq_info = modify_acq_info(select_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: field[other_idx, ...] if isinstance(field, torch.Tensor | Rotation) else field + ) # Select data kdat = self.data[other_idx, ...] diff --git a/src/mrpro/data/_kdata/KDataSplitMixin.py b/src/mrpro/data/_kdata/KDataSplitMixin.py index d2c641125..c28004af4 100644 --- a/src/mrpro/data/_kdata/KDataSplitMixin.py +++ b/src/mrpro/data/_kdata/KDataSplitMixin.py @@ -1,15 +1,17 @@ """Mixin class to split KData into other subsets.""" -import copy -from typing import Literal +from typing import Literal, TypeVar, cast import torch from einops import rearrange, repeat from typing_extensions import Self from mrpro.data._kdata.KDataProtocol import _KDataProtocol +from mrpro.data.AcqInfo import rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits -from mrpro.utils import modify_acq_info +from mrpro.data.Rotation import Rotation + +RotationOrTensor = TypeVar('RotationOrTensor', bound=torch.Tensor | Rotation) class KDataSplitMixin(_KDataProtocol): @@ -56,8 +58,9 @@ def _split_k2_or_k1_into_other( def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: return dat_traj[:, :, :, split_idx, :] - def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: - return acq_info[:, :, split_idx, ...] + def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + # cast due to https://github.com/python/mypy/issues/10817 + return cast(RotationOrTensor, acq_info[:, :, split_idx, ...]) # Rearrange other_split and k1 dimension rearrange_pattern_data = 'other coils k2 other_split k1 k0->(other other_split) coils k2 k1 k0' @@ -69,8 +72,8 @@ def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: def split_data_traj(dat_traj: torch.Tensor) -> torch.Tensor: return dat_traj[:, :, split_idx, :, :] - def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: - return acq_info[:, split_idx, ...] + def split_acq_info(acq_info: RotationOrTensor) -> RotationOrTensor: + return cast(RotationOrTensor, acq_info[:, split_idx, ...]) # Rearrange other_split and k1 dimension rearrange_pattern_data = 'other coils other_split k2 k1 k0->(other other_split) coils k2 k1 k0' @@ -93,13 +96,14 @@ def split_acq_info(acq_info: torch.Tensor) -> torch.Tensor: ktraj = rearrange(split_data_traj(ktraj), rearrange_pattern_traj) # Create new header with correct shape - kheader = copy.deepcopy(self.header) + kheader = self.header.clone() # Update shape of acquisition info index - def reshape_acq_info(info: torch.Tensor): - return rearrange(split_acq_info(info), rearrange_pattern_acq_info) - - kheader.acq_info = modify_acq_info(reshape_acq_info, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields(split_acq_info(field), rearrange_pattern_acq_info) + if isinstance(field, Rotation | torch.Tensor) + else field + ) # Update other label limits and acquisition info setattr(kheader.encoding_limits, other_label, Limits(min=0, max=n_other - 1, center=0)) diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index b16fae37a..6cd18c2cc 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -1,21 +1,22 @@ import mrpro.utils.slice_profiles import mrpro.utils.typing +import mrpro.utils.unit_conversion from mrpro.utils.smap import smap from mrpro.utils.remove_repeat import remove_repeat from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop -from mrpro.utils.modify_acq_info import modify_acq_info from mrpro.utils.split_idx import split_idx from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view + __all__ = [ "broadcast_right", - "modify_acq_info", "reduce_view", "remove_repeat", "slice_profiles", "smap", "split_idx", "typing", + "unit_conversion", "unsqueeze_left", "unsqueeze_right", "zero_pad_or_crop" -] \ No newline at end of file +] diff --git a/src/mrpro/utils/modify_acq_info.py b/src/mrpro/utils/modify_acq_info.py deleted file mode 100644 index d535e53c5..000000000 --- a/src/mrpro/utils/modify_acq_info.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Modify AcqInfo.""" - -from __future__ import annotations - -import dataclasses -from collections.abc import Callable -from typing import TYPE_CHECKING - -import torch - -if TYPE_CHECKING: - from mrpro.data.AcqInfo import AcqInfo - - -def modify_acq_info(fun_modify: Callable, acq_info: AcqInfo) -> AcqInfo: - """Go through all fields of AcqInfo object and apply changes. - - Parameters - ---------- - fun_modify - Function which takes AcqInfo fields as input and returns modified AcqInfo field - acq_info - AcqInfo object - """ - # Apply function to all fields of acq_info - for field in dataclasses.fields(acq_info): - current = getattr(acq_info, field.name) - if isinstance(current, torch.Tensor): - setattr(acq_info, field.name, fun_modify(current)) - elif dataclasses.is_dataclass(current): - for subfield in dataclasses.fields(current): - subcurrent = getattr(current, subfield.name) - setattr(current, subfield.name, fun_modify(subcurrent)) - - return acq_info diff --git a/src/mrpro/utils/unit_conversion.py b/src/mrpro/utils/unit_conversion.py new file mode 100644 index 000000000..0115bed47 --- /dev/null +++ b/src/mrpro/utils/unit_conversion.py @@ -0,0 +1,94 @@ +"""Conversion between different units.""" + +from typing import TypeVar + +import numpy as np +import torch + +__all__ = [ + 'ms_to_s', + 's_to_ms', + 'mm_to_m', + 'm_to_mm', + 'deg_to_rad', + 'rad_to_deg', + 'lamor_frequency_to_magnetic_field', + 'magnetic_field_to_lamor_frequency', + 'GYROMAGNETIC_RATIO_PROTON', +] + +GYROMAGNETIC_RATIO_PROTON = 42.58 * 1e6 +r"""The gyromagnetic ratio :math:`\frac{\gamma}{2\pi}` of 1H in H20 in Hz/T""" + +# Conversion functions for units +T = TypeVar('T', float, torch.Tensor) + + +def ms_to_s(ms: T) -> T: + """Convert ms to s.""" + return ms / 1000 + + +def s_to_ms(s: T) -> T: + """Convert s to ms.""" + return s * 1000 + + +def mm_to_m(mm: T) -> T: + """Convert mm to m.""" + return mm / 1000 + + +def m_to_mm(m: T) -> T: + """Convert m to mm.""" + return m * 1000 + + +def deg_to_rad(deg: T) -> T: + """Convert degree to radians.""" + if isinstance(deg, torch.Tensor): + return torch.deg2rad(deg) + return deg / 180.0 * np.pi + + +def rad_to_deg(deg: T) -> T: + """Convert radians to degree.""" + if isinstance(deg, torch.Tensor): + return torch.rad2deg(deg) + return deg * 180.0 / np.pi + + +def lamor_frequency_to_magnetic_field(lamor_frequency: T, gyromagnetic_ratio: float = GYROMAGNETIC_RATIO_PROTON) -> T: + """Convert the Lamor frequency [Hz] to the magntic field strength [T]. + + Parameters + ---------- + lamor_frequency + Lamor frequency [Hz] + gyromagnetic_ratio + Gyromagnetic ratio [Hz/T], default: gyromagnetic ratio of 1H proton + + Returns + ------- + Magnetic field strength [T] + """ + return lamor_frequency / gyromagnetic_ratio + + +def magnetic_field_to_lamor_frequency( + magnetic_field_strength: T, gyromagnetic_ratio: float = GYROMAGNETIC_RATIO_PROTON +) -> T: + """Convert the magntic field strength [T] to Lamor frequency [Hz]. + + Parameters + ---------- + magnetic_field_strength + Strength of the magnetic field [T] + gyromagnetic_ratio + Gyromagnetic ratio [Hz/T], default: gyromagnetic ratio of 1H proton + + Returns + ------- + Lamor frequency [Hz] + """ + return magnetic_field_strength * gyromagnetic_ratio diff --git a/tests/conftest.py b/tests/conftest.py index e3f943462..899e8959c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,9 +45,9 @@ def generate_random_acquisition_properties(generator: RandomGenerator): 'encoding_space_ref': generator.uint16(), 'sample_time_us': generator.float32(), 'position': generator.float32_tuple(3), - 'read_dir': generator.float32_tuple(3), - 'phase_dir': generator.float32_tuple(3), - 'slice_dir': generator.float32_tuple(3), + 'read_dir': (1, 0, 0), # read, phase and slice have to form rotation + 'phase_dir': (0, 1, 0), + 'slice_dir': (0, 0, 1), 'patient_table_position': generator.float32_tuple(3), 'idx': ismrmrd.EncodingCounters(**idx_properties), 'user_int': generator.uint32_tuple(8), diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index 822c63045..b7ec1fa7d 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -2,12 +2,13 @@ import pytest import torch -from einops import rearrange, repeat +from einops import repeat from mrpro.data import KData, KTrajectory, SpatialDimension from mrpro.data.acq_filters import is_coil_calibration_acquisition +from mrpro.data.AcqInfo import rearrange_acq_info_fields from mrpro.data.traj_calculators.KTrajectoryCalculator import DummyTrajectory from mrpro.operators import FastFourierOp -from mrpro.utils import modify_acq_info, split_idx +from mrpro.utils import split_idx from tests.conftest import RandomGenerator, generate_random_data from tests.data import IsmrmrdRawTestData @@ -77,10 +78,11 @@ def consistently_shaped_kdata(request, random_kheader_shape): # Start with header kheader, n_other, n_coils, n_k2, n_k1, n_k0 = random_kheader_shape - def reshape_acq_data(data): - return rearrange(data, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1) - - kheader.acq_info = modify_acq_info(reshape_acq_data, kheader.acq_info) + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1 + ) + ) # Create kdata with consistent shape kdata = generate_random_data(RandomGenerator(request.param['seed']), (n_other, n_coils, n_k2, n_k1, n_k0)) @@ -162,7 +164,7 @@ def test_KData_kspace(ismrmrd_cart): assert relative_image_difference(reconstructed_img[0, 0, 0, ...], ismrmrd_cart.img_ref) <= 0.05 -@pytest.mark.parametrize(('field', 'value'), [('b0', 11.3), ('tr', torch.tensor([24.3]))]) +@pytest.mark.parametrize(('field', 'value'), [('lamor_frequency_proton', 42.88 * 1e6), ('tr', torch.tensor([24.3]))]) def test_KData_modify_header(ismrmrd_cart, field, value): """Overwrite some parameters in the header.""" parameter_dict = {field: value} @@ -469,3 +471,21 @@ def test_KData_remove_readout_os(monkeypatch, random_kheader): # testing functions such as numpy.testing.assert_almost_equal fails because there are few voxels with high # differences along the edges of the elliptic objects. assert relative_image_difference(torch.abs(img_recon), img_tensor[:, 0, ...]) <= 0.05 + + +def test_modify_acq_info(random_kheader_shape): + """Test the modification of the acquisition info.""" + # Create random header where AcqInfo fields are of shape [n_k1*n_k2] and reshape to [n_other, n_k2, n_k1] + kheader, n_other, _, n_k2, n_k1, _ = random_kheader_shape + + kheader.acq_info.apply_( + lambda field: rearrange_acq_info_fields( + field, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1 + ) + ) + + # Verify shape + assert kheader.acq_info.center_sample.shape == (n_other, n_k2, n_k1, 1) + assert kheader.acq_info.idx.k1.shape == (n_other, n_k2, n_k1) + assert kheader.acq_info.orientation.shape == (n_other, n_k2, n_k1, 1) + assert kheader.acq_info.position.z.shape == (n_other, n_k2, n_k1, 1) diff --git a/tests/utils/test_modify_acq_info.py b/tests/utils/test_modify_acq_info.py deleted file mode 100644 index 451303d02..000000000 --- a/tests/utils/test_modify_acq_info.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Tests for modification of acquisition infos.""" - -from einops import rearrange -from mrpro.utils import modify_acq_info - - -def test_modify_acq_info(random_kheader_shape): - """Test the modification of the acquisition info.""" - # Create random header where AcqInfo fields are of shape [n_k1*n_k2] and reshape to [n_other, n_k2, n_k1] - kheader, n_other, _, n_k2, n_k1, _ = random_kheader_shape - - def reshape_acq_data(data): - return rearrange(data, '(other k2 k1) ... -> other k2 k1 ...', other=n_other, k2=n_k2, k1=n_k1) - - kheader.acq_info = modify_acq_info(reshape_acq_data, kheader.acq_info) - - # Verify shape - assert kheader.acq_info.center_sample.shape == (n_other, n_k2, n_k1, 1) diff --git a/tests/utils/test_unit_conversion.py b/tests/utils/test_unit_conversion.py new file mode 100644 index 000000000..a232a5366 --- /dev/null +++ b/tests/utils/test_unit_conversion.py @@ -0,0 +1,82 @@ +"""Tests of unit conversion.""" + +import numpy as np +import torch +from mrpro.utils.unit_conversion import ( + deg_to_rad, + lamor_frequency_to_magnetic_field, + m_to_mm, + magnetic_field_to_lamor_frequency, + mm_to_m, + ms_to_s, + rad_to_deg, + s_to_ms, +) + +from tests import RandomGenerator + + +def test_mm_to_m(): + """Verify mm to m conversion.""" + generator = RandomGenerator(seed=0) + mm_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(mm_to_m(mm_input), mm_input / 1000.0) + + +def test_m_to_mm(): + """Verify m to mm conversion.""" + generator = RandomGenerator(seed=0) + m_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(m_to_mm(m_input), m_input * 1000.0) + + +def test_ms_to_s(): + """Verify ms to s conversion.""" + generator = RandomGenerator(seed=0) + ms_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(ms_to_s(ms_input), ms_input / 1000.0) + + +def test_s_to_ms(): + """Verify s to ms conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(s_to_ms(s_input), s_input * 1000.0) + + +def test_rad_to_deg_tensor(): + """Verify radians to degree conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(rad_to_deg(s_input), torch.rad2deg(s_input)) + + +def test_deg_to_rad_tensor(): + """Verify degree to radians conversion.""" + generator = RandomGenerator(seed=0) + s_input = generator.float32_tensor((3, 4, 5)) + torch.testing.assert_close(deg_to_rad(s_input), torch.deg2rad(s_input)) + + +def test_rad_to_deg_float(): + """Verify radians to degree conversion.""" + assert rad_to_deg(np.pi / 2) == 90.0 + + +def test_deg_to_rad_float(): + """Verify degree to radians conversion.""" + assert deg_to_rad(180.0) == np.pi + + +def test_lamor_frequency_to_magnetic_field(): + """Verify conversion of lamor frequency to magnetic field.""" + proton_gyromagnetic_ratio = 42.58 * 1e6 + proton_lamor_frequency_at_3tesla = 127.74 * 1e6 + assert lamor_frequency_to_magnetic_field(proton_lamor_frequency_at_3tesla, proton_gyromagnetic_ratio) == 3.0 + + +def test_magnetic_field_to_lamor_frequency(): + """Verify conversion of magnetic field to lamor frequency.""" + proton_gyromagnetic_ratio = 42.58 * 1e6 + magnetic_field_strength = 3.0 + assert magnetic_field_to_lamor_frequency(magnetic_field_strength, proton_gyromagnetic_ratio) == 127.74 * 1e6 From 84b983cda06f3408c3b255866eb5c1087f7f2e48 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 12 Nov 2024 10:51:20 +0100 Subject: [PATCH 09/14] Release v0.241112 (#510) --- src/mrpro/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/VERSION b/src/mrpro/VERSION index 0f6ae6fb6..c60039027 100644 --- a/src/mrpro/VERSION +++ b/src/mrpro/VERSION @@ -1 +1 @@ -0.241029 +0.241112 From 762fcd777adca6427751a2d00015748d653416b4 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Tue, 12 Nov 2024 12:29:54 +0100 Subject: [PATCH 10/14] Select k-space data based on n_coils (#309) Co-authored-by: Felix F Zimmermann --- src/mrpro/data/_kdata/KData.py | 25 +++++++++++++++++- src/mrpro/data/acq_filters.py | 17 ++++++++++++ src/mrpro/utils/__init__.py | 1 + tests/data/_IsmrmrdRawTestData.py | 28 +++++++++++++------- tests/data/test_kdata.py | 44 ++++++++++++++++++++++++++++++- 5 files changed, 104 insertions(+), 11 deletions(-) diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/_kdata/KData.py index 57af617bc..d43fc49cd 100644 --- a/src/mrpro/data/_kdata/KData.py +++ b/src/mrpro/data/_kdata/KData.py @@ -17,7 +17,7 @@ from mrpro.data._kdata.KDataRemoveOsMixin import KDataRemoveOsMixin from mrpro.data._kdata.KDataSelectMixin import KDataSelectMixin from mrpro.data._kdata.KDataSplitMixin import KDataSplitMixin -from mrpro.data.acq_filters import is_image_acquisition +from mrpro.data.acq_filters import has_n_coils, is_image_acquisition from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields from mrpro.data.EncodingLimits import Limits from mrpro.data.KHeader import KHeader @@ -110,6 +110,29 @@ def from_file( modification_time = datetime.datetime.fromtimestamp(mtime) acquisitions = [acq for acq in acquisitions if acquisition_filter_criterion(acq)] + + # we need the same number of receiver coils for all acquisitions + n_coils_available = {acq.data.shape[0] for acq in acquisitions} + if len(n_coils_available) > 1: + if ( + ismrmrd_header.acquisitionSystemInformation is not None + and ismrmrd_header.acquisitionSystemInformation.receiverChannels is not None + ): + n_coils = int(ismrmrd_header.acquisitionSystemInformation.receiverChannels) + else: + # most likely, highest number of elements are the coils used for imaging + n_coils = int(max(n_coils_available)) + + warnings.warn( + f'Acquisitions with different number {n_coils_available} of receiver coil elements detected.' + 'Data with {n_coils} receiver coil elements will be used.', + stacklevel=1, + ) + acquisitions = [acq for acq in acquisitions if has_n_coils(n_coils, acq)] + + if not acquisitions: + raise ValueError('No acquisitions meeting the given filter criteria were found.') + kdata = torch.stack([torch.as_tensor(acq.data, dtype=torch.complex64) for acq in acquisitions]) acqinfo = AcqInfo.from_ismrmrd_acquisitions(acquisitions) diff --git a/src/mrpro/data/acq_filters.py b/src/mrpro/data/acq_filters.py index d64c4d9a6..4723d3bba 100644 --- a/src/mrpro/data/acq_filters.py +++ b/src/mrpro/data/acq_filters.py @@ -61,3 +61,20 @@ def is_coil_calibration_acquisition(acquisition: ismrmrd.Acquisition) -> bool: """ coil_calibration_flag = AcqFlags.ACQ_IS_PARALLEL_CALIBRATION | AcqFlags.ACQ_IS_PARALLEL_CALIBRATION_AND_IMAGING return coil_calibration_flag.value & acquisition.flags + + +def has_n_coils(n_coils: int, acquisition: ismrmrd.Acquisition) -> bool: + """Test if acquisitions was obtained with a certain number of receiver coils. + + Parameters + ---------- + n_coils + number of receiver coils + acquisition + ISMRMRD acquisition + + Returns + ------- + True if the acquisition was obtained with n_coils receiver coils + """ + return acquisition.data.shape[0] == n_coils diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index 6cd18c2cc..80ef9d398 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -6,6 +6,7 @@ from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop from mrpro.utils.split_idx import split_idx from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view +import mrpro.utils.unit_conversion __all__ = [ "broadcast_right", diff --git a/tests/data/_IsmrmrdRawTestData.py b/tests/data/_IsmrmrdRawTestData.py index 59c42be15..181be56d6 100644 --- a/tests/data/_IsmrmrdRawTestData.py +++ b/tests/data/_IsmrmrdRawTestData.py @@ -67,6 +67,7 @@ def __init__( trajectory_type: Literal['cartesian', 'radial'] = 'cartesian', sampling_order: Literal['linear', 'low_high', 'high_low', 'random'] = 'linear', phantom: EllipsePhantom | None = None, + add_bodycoil_acquisitions: bool = False, n_separate_calibration_lines: int = 0, ): if not phantom: @@ -222,23 +223,32 @@ def __init__( acq.phase_dir[1] = 1.0 acq.slice_dir[2] = 1.0 - # Initialize an acquisition counter - counter = 0 + scan_counter = 0 # Write out a few noise scans for _ in range(32): noise = self.noise_level * torch.randn(self.n_coils, n_freq_encoding, dtype=torch.complex64) # here's where we would make the noise correlated - acq.scan_counter = counter + acq.scan_counter = scan_counter acq.clearAllFlags() acq.setFlag(ismrmrd.ACQ_IS_NOISE_MEASUREMENT) acq.data[:] = noise.numpy() dataset.append_acquisition(acq) - counter += 1 # increment the scan counter + scan_counter += 1 + + # Add acquisitions obtained with a 2-element body coil (e.g. used for adjustment scans) + if add_bodycoil_acquisitions: + acq.resize(n_freq_encoding, 2, trajectory_dimensions=2) + for _ in range(8): + acq.scan_counter = scan_counter + acq.clearAllFlags() + acq.data[:] = torch.randn(2, n_freq_encoding, dtype=torch.complex64) + dataset.append_acquisition(acq) + scan_counter += 1 + acq.resize(n_freq_encoding, self.n_coils, trajectory_dimensions=2) # Calibration lines if n_separate_calibration_lines > 0: - # we take calibration lines around the k-space center traj_ky_calibration, traj_kx_calibration, kpe_calibration = self._cartesian_trajectory( n_separate_calibration_lines, n_freq_encoding, @@ -253,7 +263,7 @@ def __init__( for pe_idx, pe_pos in enumerate(kpe_calibration): # Set some fields in the header - acq.scan_counter = counter + acq.scan_counter = scan_counter # kpe is in the range [-npe//2, npe//2), the ismrmrd kspace_encoding_step_1 is in the range [0, npe) kspace_encoding_step_1 = pe_pos + n_phase_encoding // 2 @@ -264,7 +274,7 @@ def __init__( # Set the data and append acq.data[:] = kspace_calibration[:, :, pe_idx].numpy() dataset.append_acquisition(acq) - counter += 1 + scan_counter += 1 # Loop over the repetitions, add noise and write to disk for rep in range(self.repetitions): @@ -275,7 +285,7 @@ def __init__( for pe_idx, pe_pos in enumerate(kpe[rep]): if not self.flag_invalid_reps or rep == 0 or pe_idx < len(kpe[rep]) // 2: # fewer lines for rep > 0 # Set some fields in the header - acq.scan_counter = counter + acq.scan_counter = scan_counter # kpe is in the range [-npe//2, npe//2), the ismrmrd kspace_encoding_step_1 is in the range [0, npe) kspace_encoding_step_1 = pe_pos + n_phase_encoding // 2 @@ -298,7 +308,7 @@ def __init__( # Set the data and append acq.data[:] = kspace_with_noise[:, :, pe_idx].numpy() dataset.append_acquisition(acq) - counter += 1 + scan_counter += 1 # Clean up dataset.close() diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index b7ec1fa7d..ab4f5aabb 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -4,7 +4,7 @@ import torch from einops import repeat from mrpro.data import KData, KTrajectory, SpatialDimension -from mrpro.data.acq_filters import is_coil_calibration_acquisition +from mrpro.data.acq_filters import has_n_coils, is_coil_calibration_acquisition, is_image_acquisition from mrpro.data.AcqInfo import rearrange_acq_info_fields from mrpro.data.traj_calculators.KTrajectoryCalculator import DummyTrajectory from mrpro.operators import FastFourierOp @@ -29,6 +29,20 @@ def ismrmrd_cart(ellipse_phantom, tmp_path_factory): return ismrmrd_kdata +@pytest.fixture(scope='session') +def ismrmrd_cart_bodycoil_and_surface_coil(ellipse_phantom, tmp_path_factory): + """Fully sampled cartesian data set with bodycoil and surface coil data.""" + ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' + ismrmrd_kdata = IsmrmrdRawTestData( + filename=ismrmrd_filename, + noise_level=0.0, + repetitions=3, + phantom=ellipse_phantom.phantom, + add_bodycoil_acquisitions=True, + ) + return ismrmrd_kdata + + @pytest.fixture(scope='session') def ismrmrd_cart_with_calibration_lines(ellipse_phantom, tmp_path_factory): """Undersampled Cartesian data set with calibration lines.""" @@ -126,6 +140,34 @@ def test_KData_raise_wrong_trajectory_shape(ismrmrd_cart): _ = KData.from_file(ismrmrd_cart.filename, trajectory) +def test_KData_raise_warning_for_bodycoil(ismrmrd_cart_bodycoil_and_surface_coil): + """Mix of bodycoil and surface coil acquisitions leads to warning.""" + with pytest.raises(UserWarning, match='Acquisitions with different number'): + _ = KData.from_file(ismrmrd_cart_bodycoil_and_surface_coil.filename, DummyTrajectory()) + + +@pytest.mark.filterwarnings('ignore:Acquisitions with different number:UserWarning') +def test_KData_select_bodycoil_via_filter(ismrmrd_cart_bodycoil_and_surface_coil): + """Bodycoil can be selected via a custom acquisition filter.""" + # This is the recommended way of selecting the body coil (i.e. 2 receiver elements) + kdata = KData.from_file( + ismrmrd_cart_bodycoil_and_surface_coil.filename, + DummyTrajectory(), + acquisition_filter_criterion=lambda acq: has_n_coils(2, acq) and is_image_acquisition(acq), + ) + assert kdata.data.shape[-4] == 2 + + +def test_KData_raise_wrong_coil_number(ismrmrd_cart): + """Wrong number of coils leads to empty acquisitions.""" + with pytest.raises(ValueError, match='No acquisitions meeting the given filter criteria were found'): + _ = KData.from_file( + ismrmrd_cart.filename, + DummyTrajectory(), + acquisition_filter_criterion=lambda acq: has_n_coils(2, acq) and is_image_acquisition(acq), + ) + + def test_KData_from_file_diff_nky_for_rep(ismrmrd_cart_invalid_reps): """Multiple repetitions with different number of phase encoding lines.""" with pytest.warns(UserWarning, match=r'different number'): From 455679547c70e55aff86c40c0b1ee1ab98742e90 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Tue, 12 Nov 2024 13:01:04 +0100 Subject: [PATCH 11/14] Fix formatting in warning string (#514) --- src/mrpro/data/_kdata/KData.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/_kdata/KData.py index d43fc49cd..aaf430497 100644 --- a/src/mrpro/data/_kdata/KData.py +++ b/src/mrpro/data/_kdata/KData.py @@ -124,8 +124,8 @@ def from_file( n_coils = int(max(n_coils_available)) warnings.warn( - f'Acquisitions with different number {n_coils_available} of receiver coil elements detected.' - 'Data with {n_coils} receiver coil elements will be used.', + f'Acquisitions with different number {n_coils_available} of receiver coil elements detected. ' + f'Data with {n_coils} receiver coil elements will be used.', stacklevel=1, ) acquisitions = [acq for acq in acquisitions if has_n_coils(n_coils, acq)] From a8214130948447f27e1a1c540b366f865acfae5d Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 12 Nov 2024 13:05:32 +0100 Subject: [PATCH 12/14] Add human readable ids to test (#485) --- tests/algorithms/test_cg.py | 3 +- tests/conftest.py | 300 +++++++++--------- tests/data/test_rotation.py | 8 +- tests/data/test_trajectory.py | 20 +- tests/operators/functionals/conftest.py | 6 +- tests/operators/models/conftest.py | 20 +- tests/operators/test_cartesian_sampling_op.py | 16 +- tests/operators/test_fourier_op.py | 36 ++- tests/operators/test_rearrangeop.py | 1 + 9 files changed, 217 insertions(+), 193 deletions(-) diff --git a/tests/algorithms/test_cg.py b/tests/algorithms/test_cg.py index 8a4434e2a..4abda2a98 100644 --- a/tests/algorithms/test_cg.py +++ b/tests/algorithms/test_cg.py @@ -16,7 +16,8 @@ (1, 32, False), (4, 32, True), (4, 32, False), - ] + ], + ids=['complex_single', 'real_single', 'complex_batch', 'real_batch'], ) def system(request): """Generate data for creating a system Hx=b with linear and self-adjoint diff --git a/tests/conftest.py b/tests/conftest.py index 899e8959c..1fb7fb95f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -233,16 +233,16 @@ def create_uniform_traj(nk, k_shape): return k -def create_traj(k_shape, nkx, nky, nkz, sx, sy, sz): +def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): """Create trajectory with random entries.""" random_generator = RandomGenerator(seed=0) k_list = [] - for spacing, nk in zip([sz, sy, sx], [nkz, nky, nkx], strict=True): - if spacing == 'nuf': + for spacing, nk in zip([type_kz, type_ky, type_kx], [nkz, nky, nkx], strict=True): + if spacing == 'non-uniform': k = random_generator.float32_tensor(size=nk) - elif spacing == 'uf': + elif spacing == 'uniform': k = create_uniform_traj(nk, k_shape=k_shape) - elif spacing == 'z': + elif spacing == 'zero': k = torch.zeros(nk) k_list.append(k) trajectory = KTrajectory(k_list[0], k_list[1], k_list[2], repeat_detection_tolerance=None) @@ -250,161 +250,163 @@ def create_traj(k_shape, nkx, nky, nkz, sx, sy, sz): COMMON_MR_TRAJECTORIES = pytest.mark.parametrize( - ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'sx', 'sy', 'sz', 's0', 's1', 's2'), + ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz', 'type_k0', 'type_k1', 'type_k2'), [ - # (0) 2d cart mri with 1 coil, no oversampling - ( - (1, 1, 1, 96, 128), # img shape - (1, 1, 1, 96, 128), # k shape - (1, 1, 1, 128), # kx - (1, 1, 96, 1), # ky - (1, 1, 1, 1), # kz - 'uf', # kx is uniform - 'uf', # ky is uniform - 'z', # zero so no Fourier transform is performed along that dimension - 'uf', # k0 is uniform - 'uf', # k1 is uniform - 'z', # k2 is singleton + ( # (0) 2d Cartesian single coil, no oversampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 96, 128), # k_shape + (1, 1, 1, 128), # nkx + (1, 1, 96, 1), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (1) 2d cart mri with 1 coil, with oversampling - ( - (1, 1, 1, 96, 128), - (1, 1, 1, 128, 192), - (1, 1, 1, 192), - (1, 1, 128, 1), - (1, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'uf', - 'z', + ( # (1) 2d Cartesian single coil, with oversampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 128, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 128, 1), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (2) 2d non-Cartesian mri with 2 coils - ( - (1, 2, 1, 96, 128), - (1, 2, 1, 16, 192), - (1, 1, 16, 192), - (1, 1, 16, 192), - (1, 1, 1, 1), - 'nuf', # kx is non-uniform - 'nuf', - 'z', - 'nuf', - 'nuf', - 'z', + ( # (2) 2d non-Cartesian mri with 2 coils + (1, 2, 1, 96, 128), # im_shape + (1, 2, 1, 16, 192), # k_shape + (1, 1, 16, 192), # nkx + (1, 1, 16, 192), # nky + (1, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'zero', # type_k2 ), - # (3) 2d cart mri with irregular sampling - ( - (1, 1, 1, 96, 128), - (1, 1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'z', - 'z', + ( # (3) 2d Cartesian with irregular sampling + (1, 1, 1, 96, 128), # im_shape + (1, 1, 1, 1, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 1, 192), # nky + (1, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'zero', # type_k1 + 'zero', # type_k2 ), - # (4) 2d single shot spiral - ( - (1, 2, 1, 96, 128), - (1, 1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 192), - (1, 1, 1, 1), - 'nuf', - 'nuf', - 'z', - 'nuf', - 'z', - 'z', + ( # (4) 2d single shot spiral + (1, 2, 1, 96, 128), # im_shape + (1, 1, 1, 1, 192), # k_shape + (1, 1, 1, 192), # nkx + (1, 1, 1, 192), # nky + (1, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'zero', # type_k1 + 'zero', # type_k2 ), - # (5) 3d nuFFT mri, 4 coils, 2 other - ( - (2, 4, 16, 32, 64), - (2, 4, 16, 32, 64), - (2, 16, 32, 64), - (2, 16, 32, 64), - (2, 16, 32, 64), - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', + ( # (5) 3d non-uniform, 4 coils, 2 other + (2, 4, 16, 32, 64), # im_shape + (2, 4, 16, 32, 64), # k_shape + (2, 16, 32, 64), # nkx + (2, 16, 32, 64), # nky + (2, 16, 32, 64), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (6) 2d nuFFT cine mri with 8 cardiac phases, 5 coils - ( - (8, 5, 1, 64, 64), - (8, 5, 1, 18, 128), - (8, 1, 18, 128), - (8, 1, 18, 128), - (8, 1, 1, 1), - 'nuf', - 'nuf', - 'z', - 'nuf', - 'nuf', - 'z', + ( # (6) 2d non-uniform cine with 8 cardiac phases, 5 coils + (8, 5, 1, 64, 64), # im_shape + (8, 5, 1, 18, 128), # k_shape + (8, 1, 18, 128), # nkx + (8, 1, 18, 128), # nky + (8, 1, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'zero', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'zero', # type_k2 ), - # (7) 2d cart cine mri with 9 cardiac phases, 6 coils - ( - (9, 6, 1, 96, 128), - (9, 6, 1, 128, 192), - (9, 1, 1, 192), - (9, 1, 128, 1), - (9, 1, 1, 1), - 'uf', - 'uf', - 'z', - 'uf', - 'uf', - 'z', + ( # (7) 2d cartesian cine with 9 cardiac phases, 6 coils + (9, 6, 1, 96, 128), # im_shape + (9, 6, 1, 128, 192), # k_shape + (9, 1, 1, 192), # nkx + (9, 1, 128, 1), # nky + (9, 1, 1, 1), # nkz + 'uniform', # type_kx + 'uniform', # type_ky + 'zero', # type_kz + 'uniform', # type_k0 + 'uniform', # type_k1 + 'zero', # type_k2 ), - # (8) radial phase encoding (RPE), 8 coils, with oversampling in both FFT and nuFFT directions - ( - (2, 8, 64, 32, 48), - (2, 8, 8, 64, 96), - (2, 1, 1, 96), - (2, 8, 64, 1), - (2, 8, 64, 1), - 'uf', - 'nuf', - 'nuf', - 'uf', - 'nuf', - 'nuf', + ( # (8) radial phase encoding (RPE), 8 coils, with oversampling in both FFT and non-uniform directions + (2, 8, 64, 32, 48), # im_shape + (2, 8, 8, 64, 96), # k_shape + (2, 1, 1, 96), # nkx + (2, 8, 64, 1), # nky + (2, 8, 64, 1), # nkz + 'uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (9) radial phase encoding (RPE) , 8 coils with non-Cartesian sampling along readout - ( - (2, 8, 64, 32, 48), - (2, 8, 8, 64, 96), - (2, 1, 1, 96), - (2, 8, 64, 1), - (2, 8, 64, 1), - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', - 'nuf', + ( # (9) radial phase encoding (RPE), 8 coils with non-Cartesian sampling along readout + (2, 8, 64, 32, 48), # im_shape + (2, 8, 8, 64, 96), # k_shape + (2, 1, 1, 96), # nkx + (2, 8, 64, 1), # nky + (2, 8, 64, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'non-uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'non-uniform', # type_k2 ), - # (10) stack of stars, 5 other, 3 coil, oversampling in both FFT and nuFFT directions - ( - (5, 3, 48, 16, 32), - (5, 3, 96, 18, 64), - (5, 1, 18, 64), - (5, 1, 18, 64), - (5, 96, 1, 1), - 'nuf', - 'nuf', - 'uf', - 'nuf', - 'nuf', - 'uf', + ( # (10) stack of stars, 5 other, 3 coil, oversampling in both FFT and non-uniform directions + (5, 3, 48, 16, 32), # im_shape + (5, 3, 96, 18, 64), # k_shape + (5, 1, 18, 64), # nkx + (5, 1, 18, 64), # nky + (5, 96, 1, 1), # nkz + 'non-uniform', # type_kx + 'non-uniform', # type_ky + 'uniform', # type_kz + 'non-uniform', # type_k0 + 'non-uniform', # type_k1 + 'uniform', # type_k2 ), ], + ids=[ + '2d_cartesian_1_coil_no_oversampling', + '2d_cartesian_1_coil_with_oversampling', + '2d_non_cartesian_mri_2_coils', + '2d_cartesian_irregular_sampling', + '2d_single_shot_spiral', + '3d_nonuniform_4_coils_2_other', + '2d_nnonuniform_cine_mri_8_cardiac_phases_5_coils', + '2d_cartesian_cine_9_cardiac_phases_6_coils', + 'radial_phase_encoding_8_coils_with_oversampling', + 'radial_phase_encoding_8_coils_non_cartesian_sampling', + 'stack_of_stars_5_other_3_coil_with_oversampling', + ], ) diff --git a/tests/data/test_rotation.py b/tests/data/test_rotation.py index 6b4cbb52c..035a2d6c6 100644 --- a/tests/data/test_rotation.py +++ b/tests/data/test_rotation.py @@ -535,7 +535,7 @@ def _test_stats(error: torch.Tensor, mean_max: float, rms_max: float) -> None: assert torch.all(rms < rms_max) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_asymmetric_axes(seq_tuple, intrinsic): rnd = RandomGenerator(0) @@ -555,7 +555,7 @@ def test_as_euler_asymmetric_axes(seq_tuple, intrinsic): _test_stats(angles_quat - angles, 1e-15, 1e-14) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_symmetric_axes(seq_tuple, intrinsic): rnd = RandomGenerator(0) @@ -576,7 +576,7 @@ def test_as_euler_symmetric_axes(seq_tuple, intrinsic): _test_stats(angles_quat - angles, 1e-16, 1e-14) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_degenerate_asymmetric_axes(seq_tuple, intrinsic): # Since we cannot check for angle equality, we check for rotation matrix @@ -598,7 +598,7 @@ def test_as_euler_degenerate_asymmetric_axes(seq_tuple, intrinsic): torch.testing.assert_close(mat_expected, mat_estimated) -@pytest.mark.parametrize('seq_tuple', permutations('xyz')) +@pytest.mark.parametrize('seq_tuple', permutations('xyz'), ids=str) @pytest.mark.parametrize('intrinsic', [False, True]) def test_as_euler_degenerate_symmetric_axes(seq_tuple, intrinsic): # Since we cannot check for angle equality, we check for rotation matrix diff --git a/tests/data/test_trajectory.py b/tests/data/test_trajectory.py index 1baf4340b..1061a93be 100644 --- a/tests/data/test_trajectory.py +++ b/tests/data/test_trajectory.py @@ -147,16 +147,16 @@ def test_trajectory_cpu(cartesian_grid): @COMMON_MR_TRAJECTORIES -def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2): """Test identification of traj types.""" # Generate random k-space trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Find out the type of the kz, ky and kz dimensions - single_value_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'z'] - on_grid_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'uf'] - not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (sz, sy, sx), strict=True) if s == 'nuf'] + single_value_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'z'] + on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'uf'] + not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_kz, type_ky, type_kx), strict=True) if s == 'nuf'] # check dimensions which are of shape 1 and do not need any transform assert all(trajectory.type_along_kzyx[dim] & TrajType.SINGLEVALUE for dim in single_value_dims) @@ -171,16 +171,16 @@ def test_ktype_along_kzyx(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, @COMMON_MR_TRAJECTORIES -def test_ktype_along_k210(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_ktype_along_k210(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2): """Test identification of traj types.""" # Generate random k-space trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Find out the type of the k2, k1 and k0 dimensions - single_value_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'z'] - on_grid_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'uf'] - not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (s2, s1, s0), strict=True) if s == 'nuf'] + single_value_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'z'] + on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'uf'] + not_on_grid_dims = [d for d, s in zip((-3, -2, -1), (type_k2, type_k1, type_k0), strict=True) if s == 'nuf'] # check dimensions which are of shape 1 and do not need any transform assert all(trajectory.type_along_k210[dim] & TrajType.SINGLEVALUE for dim in single_value_dims) diff --git a/tests/operators/functionals/conftest.py b/tests/operators/functionals/conftest.py index 9d2bfd8c9..ae8ffc20b 100644 --- a/tests/operators/functionals/conftest.py +++ b/tests/operators/functionals/conftest.py @@ -47,12 +47,12 @@ def result_dtype(self): def functional_test_cases(func: Callable[[FunctionalTestCase], None]) -> Callable[..., None]: """Decorator combining multiple parameterizations for test cases for all proximable functionals.""" - @pytest.mark.parametrize('shape', [[1, 2, 3]]) + @pytest.mark.parametrize('shape', [[1, 2, 3]], ids=['shape=[1,2,3]']) @pytest.mark.parametrize('dtype_name', ['float32', 'complex64']) @pytest.mark.parametrize('weight', ['scalar_weight', 'tensor_weight', 'complex_weight']) @pytest.mark.parametrize('target', ['no_target', 'random_target']) - @pytest.mark.parametrize('dim', [None]) - @pytest.mark.parametrize('divide_by_n', [True, False]) + @pytest.mark.parametrize('dim', [None], ids=['dim=None']) + @pytest.mark.parametrize('divide_by_n', [True, False], ids=['mean', 'sum']) @pytest.mark.parametrize('functional', PROXIMABLE_FUNCTIONALS) def wrapper( functional: type[ElementaryProximableFunctional], diff --git a/tests/operators/models/conftest.py b/tests/operators/models/conftest.py index 570fa2f1f..75fceacd2 100644 --- a/tests/operators/models/conftest.py +++ b/tests/operators/models/conftest.py @@ -8,7 +8,7 @@ SHAPE_VARIATIONS_SIGNAL_MODELS = pytest.mark.parametrize( ('parameter_shape', 'contrast_dim_shape', 'signal_shape'), [ - ((1, 1, 10, 20, 30), (5,), (5, 1, 1, 10, 20, 30)), # single map with different inversion times + ((1, 1, 10, 20, 30), (5,), (5, 1, 1, 10, 20, 30)), # single map with different contrast times ((1, 1, 10, 20, 30), (5, 1), (5, 1, 1, 10, 20, 30)), ((4, 1, 1, 10, 20, 30), (5, 1), (5, 4, 1, 1, 10, 20, 30)), # multiple maps along additional batch dimension ((4, 1, 1, 10, 20, 30), (5,), (5, 4, 1, 1, 10, 20, 30)), @@ -25,6 +25,24 @@ ((1,), (5,), (5, 1)), # single voxel ((4, 3, 1), (5, 4, 3), (5, 4, 3, 1)), ], + ids=[ + 'single_map_diff_contrast_times', + 'single_map_diff_contrast_times_2', + 'multiple_maps_additional_batch_dim', + 'multiple_maps_additional_batch_dim_2', + 'multiple_maps_additional_batch_dim_3', + 'multiple_maps_other_dim', + 'multiple_maps_other_dim_2', + 'multiple_maps_other_dim_3', + 'multiple_maps_other_and_batch_dim', + 'multiple_maps_other_and_batch_dim_2', + 'multiple_maps_other_and_batch_dim_3', + 'multiple_maps_other_and_batch_dim_4', + 'multiple_maps_other_and_batch_dim_5', + 'different_value_each_voxel', + 'single_voxel', + 'multiple_voxels', + ], ) diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 7caa13e7c..28c6e8860 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -17,10 +17,10 @@ def test_cart_sampling_op_data_match(): nkx = (1, 1, 1, 60) nky = (1, 1, 40, 1) nkz = (1, 20, 1, 1) - sx = 'uf' - sy = 'uf' - sz = 'uf' - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + type_kx = 'uniform' + type_ky = 'uniform' + type_kz = 'uniform' + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # Create matching data random_generator = RandomGenerator(seed=0) @@ -73,10 +73,10 @@ def test_cart_sampling_op_fwd_adj(sampling): nkx = (2, 1, 1, 60) nky = (2, 1, 40, 1) nkz = (2, 20, 1, 1) - sx = 'uf' - sy = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' - sz = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf' - trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz).as_tensor() + type_kx = 'uniform' + type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform' + trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz).as_tensor() # Subsample data and trajectory match sampling: diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index 89a4bbc11..6f8c377cf 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -9,22 +9,24 @@ from tests.helper import dotproduct_adjointness_test -def create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz): +def create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): random_generator = RandomGenerator(seed=0) # generate random image img = random_generator.complex64_tensor(size=im_shape) # create random trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) return img, trajectory @COMMON_MR_TRAJECTORIES -def test_fourier_fwd_adj_property(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, s0, s1, s2): +def test_fourier_op_fwd_adj_property( + im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2 +): """Test adjoint property of Fourier operator.""" # generate random images and k-space trajectories - img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz) + img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) @@ -42,26 +44,26 @@ def test_fourier_fwd_adj_property(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz, @pytest.mark.parametrize( - ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'sx', 'sy', 'sz'), + ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz'), # parameter names [ - # Cartesian FFT dimensions are not aligned with corresponding k2, k1, k0 dimensions - ( - (5, 3, 48, 16, 32), - (5, 3, 96, 18, 64), - (5, 1, 18, 64), - (5, 96, 1, 1), # Cartesian ky dimension defined along k2 rather than k1 - (5, 1, 18, 64), - 'nuf', - 'uf', - 'nuf', + ( # Cartesian FFT dimensions are not aligned with corresponding k2, k1, k0 dimensions + (5, 3, 48, 16, 32), # im_shape + (5, 3, 96, 18, 64), # k_shape + (5, 1, 18, 64), # nkx + (5, 96, 1, 1), # nky - Cartesian ky dimension defined along k2 rather than k1 + (5, 1, 18, 64), # nkz + 'non-uniform', # type_kx + 'uniform', # type_ky + 'non-uniform', # type_kz ), ], + ids=['cartesian_fft_dims_not_aligned_with_k2_k1_k0_dims'], ) -def test_fourier_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz): +def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): """Test trajectory not supported by Fourier operator.""" # generate random images and k-space trajectories - img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz) + img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) diff --git a/tests/operators/test_rearrangeop.py b/tests/operators/test_rearrangeop.py index 3db888897..054c99402 100644 --- a/tests/operators/test_rearrangeop.py +++ b/tests/operators/test_rearrangeop.py @@ -15,6 +15,7 @@ ((2, 2, 4), '... a b->... (a b)', (2, 8), {'b': 4}), # flatten ((2), '... (a b) -> ... a b', (2, 1), {'b': 1}), # unflatten ], + ids=['swap_axes', 'flatten', 'unflatten'], ) def test_einsum_op(input_shape, rule, output_shape, additional_info, dtype): """Test adjointness and shape.""" From 72c18070d45bbad969bf2244b51163b10f5b627a Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Tue, 12 Nov 2024 15:31:57 +0100 Subject: [PATCH 13/14] Add CartesianSamplingOp to FourierOp (#482) Co-authored-by: Felix F Zimmermann --- src/mrpro/operators/FourierOp.py | 41 +++++++++++++++++++----------- tests/conftest.py | 14 ++++++++++ tests/data/test_kdata.py | 13 ---------- tests/operators/test_fourier_op.py | 29 +++++++++++++++++++-- 4 files changed, 67 insertions(+), 30 deletions(-) diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index f4254d8d2..c57d1fbfd 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -11,6 +11,7 @@ from mrpro.data.enums import TrajType from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.operators.CartesianSamplingOp import CartesianSamplingOp from mrpro.operators.FastFourierOp import FastFourierOp from mrpro.operators.LinearOperator import LinearOperator @@ -67,12 +68,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): self._nufft_dims.append(dim) if self._fft_dims: - self._fast_fourier_op = FastFourierOp( + self._fast_fourier_op: FastFourierOp | None = FastFourierOp( dim=tuple(self._fft_dims), recon_matrix=get_spatial_dims(recon_matrix, self._fft_dims), encoding_matrix=get_spatial_dims(encoding_matrix, self._fft_dims), ) - + self._cart_sampling_op: CartesianSamplingOp | None = CartesianSamplingOp( + encoding_matrix=encoding_matrix, traj=traj + ) + else: + self._fast_fourier_op = None + self._cart_sampling_op = None # Find dimensions which require NUFFT if self._nufft_dims: fft_dims_k210 = [ @@ -102,20 +108,23 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega] self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction - self._fwd_nufft_op = KbNufft( + self._fwd_nufft_op: KbNufftAdjoint | None = KbNufft( im_size=self._nufft_im_size, grid_size=grid_size, numpoints=nufft_numpoints, kbwidth=nufft_kbwidth, ) - self._adj_nufft_op = KbNufftAdjoint( + self._adj_nufft_op: KbNufftAdjoint | None = KbNufftAdjoint( im_size=self._nufft_im_size, grid_size=grid_size, numpoints=nufft_numpoints, kbwidth=nufft_kbwidth, ) - - self._kshape = traj.broadcasted_shape + else: + self._omega: torch.Tensor | None = None + self._fwd_nufft_op = None + self._adj_nufft_op = None + self._kshape = traj.broadcasted_shape @classmethod def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> Self: @@ -146,11 +155,8 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil k-space data with shape: (... coils k2 k1 k0) """ - if len(self._fft_dims): - # FFT - (x,) = self._fast_fourier_op(x) - - if self._nufft_dims: + if self._fwd_nufft_op is not None and self._omega is not None: + # NUFFT Type 2 # we need to move the nufft-dimensions to the end and flatten all other dimensions # so the new shape will be (... non_nufft_dims) coils nufft_dims # we could move the permute to __init__ but then we still would need to prepend if len(other)>1 @@ -163,7 +169,6 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: x = x.flatten(end_dim=-len(keep_dims) - 1) # omega should be (... non_nufft_dims) n_nufft_dims (nufft_dims) - # TODO: consider moving the broadcast along fft dimensions to __init__ (independent of x shape). omega = self._omega.permute(*permute) omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :]) omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1) @@ -173,6 +178,11 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: shape_nufft_dims = [self._kshape[i] for i in self._nufft_dims] x = x.reshape(*permuted_x_shape[: -len(keep_dims)], -1, *shape_nufft_dims) # -1 is coils x = x.permute(*unpermute) + + if self._fast_fourier_op is not None and self._cart_sampling_op is not None: + # FFT + (x,) = self._cart_sampling_op(self._fast_fourier_op(x)[0]) + return (x,) def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: @@ -187,11 +197,12 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil image data with shape: (... coils z y x) """ - if self._fft_dims: + if self._fast_fourier_op is not None and self._cart_sampling_op is not None: # IFFT - (x,) = self._fast_fourier_op.adjoint(x) + (x,) = self._fast_fourier_op.adjoint(self._cart_sampling_op.adjoint(x)[0]) - if self._nufft_dims: + if self._adj_nufft_op is not None and self._omega is not None: + # NUFFT Type 1 # we need to move the nufft-dimensions to the end, flatten them and flatten all other dimensions # so the new shape will be (... non_nufft_dims) coils (nufft_dims) keep_dims = [-4, *self._nufft_dims] # -4 is coil diff --git a/tests/conftest.py b/tests/conftest.py index 1fb7fb95f..3bd8946f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from xsdata.models.datatype import XmlDate, XmlTime from tests import RandomGenerator +from tests.data import IsmrmrdRawTestData from tests.phantoms import EllipsePhantomTestData @@ -249,6 +250,19 @@ def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): return trajectory +@pytest.fixture(scope='session') +def ismrmrd_cart(ellipse_phantom, tmp_path_factory): + """Fully sampled cartesian data set.""" + ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' + ismrmrd_kdata = IsmrmrdRawTestData( + filename=ismrmrd_filename, + noise_level=0.0, + repetitions=3, + phantom=ellipse_phantom.phantom, + ) + return ismrmrd_kdata + + COMMON_MR_TRAJECTORIES = pytest.mark.parametrize( ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz', 'type_k0', 'type_k1', 'type_k2'), [ diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index ab4f5aabb..d5cfa0f0c 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -16,19 +16,6 @@ from tests.phantoms import EllipsePhantomTestData -@pytest.fixture(scope='session') -def ismrmrd_cart(ellipse_phantom, tmp_path_factory): - """Fully sampled cartesian data set.""" - ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' - ismrmrd_kdata = IsmrmrdRawTestData( - filename=ismrmrd_filename, - noise_level=0.0, - repetitions=3, - phantom=ellipse_phantom.phantom, - ) - return ismrmrd_kdata - - @pytest.fixture(scope='session') def ismrmrd_cart_bodycoil_and_surface_coil(ellipse_phantom, tmp_path_factory): """Fully sampled cartesian data set with bodycoil and surface coil data.""" diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index 6f8c377cf..f48a24260 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -1,7 +1,9 @@ """Tests for Fourier operator.""" import pytest -from mrpro.data import SpatialDimension +import torch +from mrpro.data import KData, KTrajectory, SpatialDimension +from mrpro.data.traj_calculators import KTrajectoryCartesian from mrpro.operators import FourierOp from tests import RandomGenerator @@ -30,7 +32,11 @@ def test_fourier_op_fwd_adj_property( # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) - encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) # apply forward operator @@ -70,3 +76,22 @@ def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) with pytest.raises(NotImplementedError, match='Cartesian FFT dims need to be aligned'): FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) + + +def test_fourier_op_cartesian_sorting(ismrmrd_cart): + """Verify correct sorting of Cartesian k-space data before FFT.""" + kdata = KData.from_file(ismrmrd_cart.filename, KTrajectoryCartesian()) + ff_op = FourierOp.from_kdata(kdata) + (img,) = ff_op.adjoint(kdata.data) + + # shuffle the kspace points along k0 + permutation_index = torch.randperm(kdata.data.shape[-1]) + kdata_unsorted = KData( + header=kdata.header, + data=kdata.data[..., permutation_index], + traj=KTrajectory.from_tensor(kdata.traj.as_tensor()[..., permutation_index]), + ) + ff_op_unsorted = FourierOp.from_kdata(kdata_unsorted) + (img_unsorted,) = ff_op_unsorted.adjoint(kdata_unsorted.data) + + torch.testing.assert_close(img, img_unsorted) From 202d395a2c10f2c3db7c491851988fd26dca14a5 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Tue, 12 Nov 2024 17:17:18 +0100 Subject: [PATCH 14/14] Exclude data outside of encoding_matrix (#234) Co-authored-by: Felix Zimmermann --- src/mrpro/operators/CartesianSamplingOp.py | 98 ++++++++++++++++--- tests/conftest.py | 2 +- tests/operators/test_cartesian_sampling_op.py | 25 +++++ tests/operators/test_fourier_op.py | 6 +- 4 files changed, 115 insertions(+), 16 deletions(-) diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index 7a51924b1..07f8aba65 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -1,5 +1,7 @@ """Cartesian Sampling Operator.""" +import warnings + import torch from einops import rearrange, repeat @@ -7,6 +9,7 @@ from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension from mrpro.operators.LinearOperator import LinearOperator +from mrpro.utils.reshape import unsqueeze_left class CartesianSamplingOp(LinearOperator): @@ -64,10 +67,35 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) -> # 1D indices into a flattened tensor. kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)') + + # check that all points are inside the encoding matrix + inside_encoding_matrix = ( + ((kx_idx >= 0) & (kx_idx < sorted_grid_shape.x)) + & ((ky_idx >= 0) & (ky_idx < sorted_grid_shape.y)) + & ((kz_idx >= 0) & (kz_idx < sorted_grid_shape.z)) + ) + if not torch.all(inside_encoding_matrix): + warnings.warn( + 'K-space points lie outside of the encoding_matrix and will be ignored.' + ' Increase the encoding_matrix to include these points.', + stacklevel=2, + ) + + inside_encoding_matrix = rearrange(inside_encoding_matrix, '... kz ky kx -> ... 1 (kz ky kx)') + inside_encoding_matrix_idx = inside_encoding_matrix.nonzero(as_tuple=True)[-1] + inside_encoding_matrix_idx = torch.reshape(inside_encoding_matrix_idx, (*kidx.shape[:-1], -1)) + self.register_buffer('_inside_encoding_matrix_idx', inside_encoding_matrix_idx) + kidx = torch.take_along_dim(kidx, inside_encoding_matrix_idx, dim=-1) + else: + self._inside_encoding_matrix_idx: torch.Tensor | None = None + self.register_buffer('_fft_idx', kidx) + # we can skip the indexing if the data is already sorted self._needs_indexing = ( - not torch.all(torch.diff(kidx) == 1) or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx + not torch.all(torch.diff(kidx) == 1) + or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx + or self._inside_encoding_matrix_idx is not None ) self._trajectory_shape = traj.broadcasted_shape @@ -93,8 +121,21 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: return (x,) x_kflat = rearrange(x, '... coil k2_enc k1_enc k0_enc -> ... coil (k2_enc k1_enc k0_enc)') - # take_along_dim does broadcast, so no need for extending here - x_indexed = torch.take_along_dim(x_kflat, self._fft_idx, dim=-1) + # take_along_dim broadcasts, but needs the same number of dimensions + idx = unsqueeze_left(self._fft_idx, x_kflat.ndim - self._fft_idx.ndim) + x_inside_encoding_matrix = torch.take_along_dim(x_kflat, idx, dim=-1) + + if self._inside_encoding_matrix_idx is None: + # all trajectory points are inside the encoding matrix + x_indexed = x_inside_encoding_matrix + else: + # we need to add zeros + x_indexed = self._broadcast_and_scatter_along_last_dim( + x_inside_encoding_matrix, + self._trajectory_shape[-1] * self._trajectory_shape[-2] * self._trajectory_shape[-3], + self._inside_encoding_matrix_idx, + ) + # reshape to (... other coil, k2, k1, k0) x_reshaped = x_indexed.reshape(x.shape[:-3] + self._trajectory_shape[-3:]) @@ -120,18 +161,13 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: y_kflat = rearrange(y, '... coil k2 k1 k0 -> ... coil (k2 k1 k0)') - # scatter does not broadcast, so we need to manually broadcast the indices - broadcast_shape = torch.broadcast_shapes(self._fft_idx.shape[:-1], y_kflat.shape[:-1]) - idx_expanded = torch.broadcast_to(self._fft_idx, (*broadcast_shape, self._fft_idx.shape[-1])) + if self._inside_encoding_matrix_idx is not None: + idx = unsqueeze_left(self._inside_encoding_matrix_idx, y_kflat.ndim - self._inside_encoding_matrix_idx.ndim) + y_kflat = torch.take_along_dim(y_kflat, idx, dim=-1) - # although scatter_ is inplace, this will not cause issues with autograd, as self - # is always constant zero and gradients w.r.t. src work as expected. - y_scattered = torch.zeros( - *broadcast_shape, - self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, - dtype=y.dtype, - device=y.device, - ).scatter_(dim=-1, index=idx_expanded, src=y_kflat) + y_scattered = self._broadcast_and_scatter_along_last_dim( + y_kflat, self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, self._fft_idx + ) # reshape to ..., other, coil, k2_enc, k1_enc, k0_enc y_reshaped = y_scattered.reshape( @@ -142,3 +178,37 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: ) return (y_reshaped,) + + @staticmethod + def _broadcast_and_scatter_along_last_dim( + data_to_scatter: torch.Tensor, n_last_dim: int, scatter_index: torch.Tensor + ) -> torch.Tensor: + """Broadcast scatter index and scatter into zero tensor. + + Parameters + ---------- + data_to_scatter + Data to be scattered at indices scatter_index + n_last_dim + Number of data points in last dimension + scatter_index + Indices describing where to scatter data + + Returns + ------- + Data scattered into tensor along scatter_index + """ + # scatter does not broadcast, so we need to manually broadcast the indices + broadcast_shape = torch.broadcast_shapes(scatter_index.shape[:-1], data_to_scatter.shape[:-1]) + idx_expanded = torch.broadcast_to(scatter_index, (*broadcast_shape, scatter_index.shape[-1])) + + # although scatter_ is inplace, this will not cause issues with autograd, as self + # is always constant zero and gradients w.r.t. src work as expected. + data_scattered = torch.zeros( + *broadcast_shape, + n_last_dim, + dtype=data_to_scatter.dtype, + device=data_to_scatter.device, + ).scatter_(dim=-1, index=idx_expanded, src=data_to_scatter) + + return data_scattered diff --git a/tests/conftest.py b/tests/conftest.py index 3bd8946f2..30ae9c229 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -240,7 +240,7 @@ def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): k_list = [] for spacing, nk in zip([type_kz, type_ky, type_kx], [nkz, nky, nkx], strict=True): if spacing == 'non-uniform': - k = random_generator.float32_tensor(size=nk) + k = random_generator.float32_tensor(size=nk, low=-1, high=1) * max(nk) elif spacing == 'uniform': k = create_uniform_traj(nk, k_shape=k_shape) elif spacing == 'zero': diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 28c6e8860..c1738b7bb 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -118,3 +118,28 @@ def test_cart_sampling_op_fwd_adj(sampling): u = random_generator.complex64_tensor(size=k_shape) v = random_generator.complex64_tensor(size=k_shape[:2] + trajectory.as_tensor().shape[2:]) dotproduct_adjointness_test(sampling_op, u, v) + + +@pytest.mark.parametrize(('k2_min', 'k2_max'), [(-1, 21), (-21, 1)]) +@pytest.mark.parametrize(('k0_min', 'k0_max'), [(-6, 13), (-13, 6)]) +def test_cart_sampling_op_oversampling(k0_min, k0_max, k2_min, k2_max): + """Test trajectory points outside of encoding_matrix.""" + encoding_matrix = SpatialDimension(40, 1, 20) + + # Create kx and kz sampling which are asymmetric and larger than the encoding matrix on one side + # The indices are inverted to ensure CartesianSamplingOp acts on them + kx = rearrange(torch.linspace(k0_max, k0_min, 20), 'kx->1 1 1 kx') + ky = torch.ones(1, 1, 1, 1) + kz = rearrange(torch.linspace(k2_max, k2_min, 40), 'kz-> kz 1 1') + kz = torch.stack([kz, -kz], dim=0) # different kz values for two other elements + trajectory = KTrajectory(kz=kz, ky=ky, kx=kx) + + with pytest.warns(UserWarning, match='K-space points lie outside of the encoding_matrix'): + sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory) + + random_generator = RandomGenerator(seed=0) + u = random_generator.complex64_tensor(size=(3, 2, 5, kz.shape[-3], ky.shape[-2], kx.shape[-1])) + v = random_generator.complex64_tensor(size=(3, 2, 5, *encoding_matrix.zyx)) + + assert sampling_op.adjoint(u)[0].shape[-3:] == encoding_matrix.zyx + assert sampling_op(v)[0].shape[-3:] == (kz.shape[-3], ky.shape[-2], kx.shape[-1]) diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index f48a24260..2d76642c3 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -73,7 +73,11 @@ def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) - encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) with pytest.raises(NotImplementedError, match='Cartesian FFT dims need to be aligned'): FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory)