Skip to content

Commit

Permalink
Add CartesianSamplingOp to FourierOp (#482)
Browse files Browse the repository at this point in the history
Co-authored-by: Felix F Zimmermann <fzimmermann89@gmail.com>
  • Loading branch information
ckolbPTB and fzimmermann89 authored Nov 12, 2024
1 parent a821413 commit 72c1807
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 30 deletions.
41 changes: 26 additions & 15 deletions src/mrpro/operators/FourierOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mrpro.data.enums import TrajType
from mrpro.data.KTrajectory import KTrajectory
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.operators.CartesianSamplingOp import CartesianSamplingOp
from mrpro.operators.FastFourierOp import FastFourierOp
from mrpro.operators.LinearOperator import LinearOperator

Expand Down Expand Up @@ -67,12 +68,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
self._nufft_dims.append(dim)

if self._fft_dims:
self._fast_fourier_op = FastFourierOp(
self._fast_fourier_op: FastFourierOp | None = FastFourierOp(
dim=tuple(self._fft_dims),
recon_matrix=get_spatial_dims(recon_matrix, self._fft_dims),
encoding_matrix=get_spatial_dims(encoding_matrix, self._fft_dims),
)

self._cart_sampling_op: CartesianSamplingOp | None = CartesianSamplingOp(
encoding_matrix=encoding_matrix, traj=traj
)
else:
self._fast_fourier_op = None
self._cart_sampling_op = None
# Find dimensions which require NUFFT
if self._nufft_dims:
fft_dims_k210 = [
Expand Down Expand Up @@ -102,20 +108,23 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega]
self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction

self._fwd_nufft_op = KbNufft(
self._fwd_nufft_op: KbNufftAdjoint | None = KbNufft(
im_size=self._nufft_im_size,
grid_size=grid_size,
numpoints=nufft_numpoints,
kbwidth=nufft_kbwidth,
)
self._adj_nufft_op = KbNufftAdjoint(
self._adj_nufft_op: KbNufftAdjoint | None = KbNufftAdjoint(
im_size=self._nufft_im_size,
grid_size=grid_size,
numpoints=nufft_numpoints,
kbwidth=nufft_kbwidth,
)

self._kshape = traj.broadcasted_shape
else:
self._omega: torch.Tensor | None = None
self._fwd_nufft_op = None
self._adj_nufft_op = None
self._kshape = traj.broadcasted_shape

@classmethod
def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> Self:
Expand Down Expand Up @@ -146,11 +155,8 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
-------
coil k-space data with shape: (... coils k2 k1 k0)
"""
if len(self._fft_dims):
# FFT
(x,) = self._fast_fourier_op(x)

if self._nufft_dims:
if self._fwd_nufft_op is not None and self._omega is not None:
# NUFFT Type 2
# we need to move the nufft-dimensions to the end and flatten all other dimensions
# so the new shape will be (... non_nufft_dims) coils nufft_dims
# we could move the permute to __init__ but then we still would need to prepend if len(other)>1
Expand All @@ -163,7 +169,6 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
x = x.flatten(end_dim=-len(keep_dims) - 1)

# omega should be (... non_nufft_dims) n_nufft_dims (nufft_dims)
# TODO: consider moving the broadcast along fft dimensions to __init__ (independent of x shape).
omega = self._omega.permute(*permute)
omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :])
omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)
Expand All @@ -173,6 +178,11 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
shape_nufft_dims = [self._kshape[i] for i in self._nufft_dims]
x = x.reshape(*permuted_x_shape[: -len(keep_dims)], -1, *shape_nufft_dims) # -1 is coils
x = x.permute(*unpermute)

if self._fast_fourier_op is not None and self._cart_sampling_op is not None:
# FFT
(x,) = self._cart_sampling_op(self._fast_fourier_op(x)[0])

return (x,)

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
Expand All @@ -187,11 +197,12 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
-------
coil image data with shape: (... coils z y x)
"""
if self._fft_dims:
if self._fast_fourier_op is not None and self._cart_sampling_op is not None:
# IFFT
(x,) = self._fast_fourier_op.adjoint(x)
(x,) = self._fast_fourier_op.adjoint(self._cart_sampling_op.adjoint(x)[0])

if self._nufft_dims:
if self._adj_nufft_op is not None and self._omega is not None:
# NUFFT Type 1
# we need to move the nufft-dimensions to the end, flatten them and flatten all other dimensions
# so the new shape will be (... non_nufft_dims) coils (nufft_dims)
keep_dims = [-4, *self._nufft_dims] # -4 is coil
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from xsdata.models.datatype import XmlDate, XmlTime

from tests import RandomGenerator
from tests.data import IsmrmrdRawTestData
from tests.phantoms import EllipsePhantomTestData


Expand Down Expand Up @@ -249,6 +250,19 @@ def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz):
return trajectory


@pytest.fixture(scope='session')
def ismrmrd_cart(ellipse_phantom, tmp_path_factory):
"""Fully sampled cartesian data set."""
ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5'
ismrmrd_kdata = IsmrmrdRawTestData(
filename=ismrmrd_filename,
noise_level=0.0,
repetitions=3,
phantom=ellipse_phantom.phantom,
)
return ismrmrd_kdata


COMMON_MR_TRAJECTORIES = pytest.mark.parametrize(
('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz', 'type_k0', 'type_k1', 'type_k2'),
[
Expand Down
13 changes: 0 additions & 13 deletions tests/data/test_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,6 @@
from tests.phantoms import EllipsePhantomTestData


@pytest.fixture(scope='session')
def ismrmrd_cart(ellipse_phantom, tmp_path_factory):
"""Fully sampled cartesian data set."""
ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5'
ismrmrd_kdata = IsmrmrdRawTestData(
filename=ismrmrd_filename,
noise_level=0.0,
repetitions=3,
phantom=ellipse_phantom.phantom,
)
return ismrmrd_kdata


@pytest.fixture(scope='session')
def ismrmrd_cart_bodycoil_and_surface_coil(ellipse_phantom, tmp_path_factory):
"""Fully sampled cartesian data set with bodycoil and surface coil data."""
Expand Down
29 changes: 27 additions & 2 deletions tests/operators/test_fourier_op.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Tests for Fourier operator."""

import pytest
from mrpro.data import SpatialDimension
import torch
from mrpro.data import KData, KTrajectory, SpatialDimension
from mrpro.data.traj_calculators import KTrajectoryCartesian
from mrpro.operators import FourierOp

from tests import RandomGenerator
Expand Down Expand Up @@ -30,7 +32,11 @@ def test_fourier_op_fwd_adj_property(

# create operator
recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1])
encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1])
encoding_matrix = SpatialDimension(
int(trajectory.kz.max() - trajectory.kz.min() + 1),
int(trajectory.ky.max() - trajectory.ky.min() + 1),
int(trajectory.kx.max() - trajectory.kx.min() + 1),
)
fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory)

# apply forward operator
Expand Down Expand Up @@ -70,3 +76,22 @@ def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx
encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1])
with pytest.raises(NotImplementedError, match='Cartesian FFT dims need to be aligned'):
FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory)


def test_fourier_op_cartesian_sorting(ismrmrd_cart):
"""Verify correct sorting of Cartesian k-space data before FFT."""
kdata = KData.from_file(ismrmrd_cart.filename, KTrajectoryCartesian())
ff_op = FourierOp.from_kdata(kdata)
(img,) = ff_op.adjoint(kdata.data)

# shuffle the kspace points along k0
permutation_index = torch.randperm(kdata.data.shape[-1])
kdata_unsorted = KData(
header=kdata.header,
data=kdata.data[..., permutation_index],
traj=KTrajectory.from_tensor(kdata.traj.as_tensor()[..., permutation_index]),
)
ff_op_unsorted = FourierOp.from_kdata(kdata_unsorted)
(img_unsorted,) = ff_op_unsorted.adjoint(kdata_unsorted.data)

torch.testing.assert_close(img, img_unsorted)

3 comments on commit 72c1807

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/mrpro/algorithms/csm
   inati.py24196%44
   walsh.py16194%34
src/mrpro/algorithms/dcf
   dcf_voronoi.py53492%15, 48–49, 76
src/mrpro/algorithms/optimizers
   adam.py20195%69
src/mrpro/algorithms/reconstruction
   DirectReconstruction.py281643%51–71, 85
   IterativeSENSEReconstruction.py13192%76
   Reconstruction.py502256%42, 54–56, 80–87, 104–113
   RegularizedIterativeSENSEReconstruction.py411759%96–100, 114–139
src/mrpro/data
   AcqInfo.py128398%26, 169, 207
   CsmData.py29390%15, 82–84
   DcfData.py45882%18, 66, 78–83
   IData.py67987%119, 125, 129, 159–167
   IHeader.py75791%75, 109, 127–131
   KHeader.py1531789%25, 119–123, 150, 199, 210, 217–218, 221, 228, 260–271
   KNoise.py311552%39–52, 56–61
   KTrajectory.py69593%178–182
   MoveDataMixin.py1371887%15, 113, 129, 143–145, 207, 305–307, 320, 399, 419–420, 422, 437–438, 440
   QData.py39782%42, 65–73
   Rotation.py6743595%100, 198, 335, 433, 477, 495, 581, 583, 592, 626, 628, 691, 768, 773, 776, 791, 808, 813, 889, 1077, 1082, 1085, 1109, 1113, 1240, 1242, 1250–1251, 1315, 1397, 1690, 1846, 1881, 1885, 1996
   SpatialDimension.py2302191%33, 103, 128, 135, 141, 261–263, 276–278, 312, 330, 343, 356, 369, 382, 391–392, 407, 416
   acq_filters.py12192%47
src/mrpro/data/_kdata
   KData.py1121884%108–109, 124, 131, 141, 149, 203–204, 242, 247–248, 267–278
   KDataRemoveOsMixin.py29293%44, 46
   KDataSelectMixin.py19289%48, 63
   KDataSplitMixin.py48394%53, 84, 93
src/mrpro/data/traj_calculators
   KTrajectoryCalculator.py25292%23, 45
   KTrajectoryIsmrmrd.py13285%41, 50
   KTrajectoryPulseq.py29197%54
src/mrpro/operators
   CartesianSamplingOp.py50296%90, 116
   ConstraintsOp.py60297%46, 48
   EndomorphOperator.py65297%228, 234
   FiniteDifferenceOp.py27293%40, 105
   Functional.py71593%20–22, 117, 119
   GridSamplingOp.py136993%72–73, 82–83, 90–91, 94, 96, 98
   LinearOperator.py1711293%55, 91, 190, 220, 261, 270, 278, 287, 295, 320, 418, 423
   LinearOperatorMatrix.py1581690%82, 119, 152, 161, 166, 175–178, 191–194, 203, 215, 304, 331, 359
   MultiIdentityOp.py13285%43, 48
   Operator.py78297%25, 74
   ProximableFunctionalSeparableSum.py39392%50, 103, 110
   SliceProjectionOp.py173895%44, 61, 63, 69, 206, 227, 260, 300
   WaveletOp.py120596%152, 170, 205, 210, 233
   ZeroPadOp.py16194%30
src/mrpro/utils
   filters.py62297%44, 49
   slice_profiles.py46687%20, 36, 113–116, 149
   sliding_window.py34197%34
   split_idx.py10280%43, 47
   summarize_tensorvalues.py11918%20–29
   typing.py181139%8–23
   zero_pad_or_crop.py31681%26, 30, 54, 57, 60, 63
TOTAL473235093% 

Tests Skipped Failures Errors Time
1967 0 💤 0 ❌ 0 🔥 1m 31s ⏱️

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/mrpro/algorithms/csm
   inati.py24196%44
   walsh.py16194%34
src/mrpro/algorithms/dcf
   dcf_voronoi.py53492%15, 48–49, 76
src/mrpro/algorithms/optimizers
   adam.py20195%69
src/mrpro/algorithms/reconstruction
   DirectReconstruction.py281643%51–71, 85
   IterativeSENSEReconstruction.py13192%76
   Reconstruction.py502256%42, 54–56, 80–87, 104–113
   RegularizedIterativeSENSEReconstruction.py411759%96–100, 114–139
src/mrpro/data
   AcqInfo.py128398%26, 169, 207
   CsmData.py29390%15, 82–84
   DcfData.py45882%18, 66, 78–83
   IData.py67987%119, 125, 129, 159–167
   IHeader.py75791%75, 109, 127–131
   KHeader.py1531789%25, 119–123, 150, 199, 210, 217–218, 221, 228, 260–271
   KNoise.py311552%39–52, 56–61
   KTrajectory.py69593%178–182
   MoveDataMixin.py1371887%15, 113, 129, 143–145, 207, 305–307, 320, 399, 419–420, 422, 437–438, 440
   QData.py39782%42, 65–73
   Rotation.py6743595%100, 198, 335, 433, 477, 495, 581, 583, 592, 626, 628, 691, 768, 773, 776, 791, 808, 813, 889, 1077, 1082, 1085, 1109, 1113, 1240, 1242, 1250–1251, 1315, 1397, 1690, 1846, 1881, 1885, 1996
   SpatialDimension.py2302191%33, 103, 128, 135, 141, 261–263, 276–278, 312, 330, 343, 356, 369, 382, 391–392, 407, 416
   acq_filters.py12192%47
src/mrpro/data/_kdata
   KData.py1121884%108–109, 124, 131, 141, 149, 203–204, 242, 247–248, 267–278
   KDataRemoveOsMixin.py29293%44, 46
   KDataSelectMixin.py19289%48, 63
   KDataSplitMixin.py48394%53, 84, 93
src/mrpro/data/traj_calculators
   KTrajectoryCalculator.py25292%23, 45
   KTrajectoryIsmrmrd.py13285%41, 50
   KTrajectoryPulseq.py29197%54
src/mrpro/operators
   CartesianSamplingOp.py50296%90, 116
   ConstraintsOp.py60297%46, 48
   EndomorphOperator.py65297%228, 234
   FiniteDifferenceOp.py27293%40, 105
   Functional.py71593%20–22, 117, 119
   GridSamplingOp.py136993%72–73, 82–83, 90–91, 94, 96, 98
   LinearOperator.py1711293%55, 91, 190, 220, 261, 270, 278, 287, 295, 320, 418, 423
   LinearOperatorMatrix.py1581690%82, 119, 152, 161, 166, 175–178, 191–194, 203, 215, 304, 331, 359
   MultiIdentityOp.py13285%43, 48
   Operator.py78297%25, 74
   ProximableFunctionalSeparableSum.py39392%50, 103, 110
   SliceProjectionOp.py173895%44, 61, 63, 69, 206, 227, 260, 300
   WaveletOp.py120596%152, 170, 205, 210, 233
   ZeroPadOp.py16194%30
src/mrpro/utils
   filters.py62297%44, 49
   slice_profiles.py46687%20, 36, 113–116, 149
   sliding_window.py34197%34
   split_idx.py10280%43, 47
   summarize_tensorvalues.py11918%20–29
   typing.py181139%8–23
   zero_pad_or_crop.py31681%26, 30, 54, 57, 60, 63
TOTAL473235093% 

Tests Skipped Failures Errors Time
1967 0 💤 0 ❌ 0 🔥 1m 30s ⏱️

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/mrpro/algorithms/csm
   inati.py24196%44
   walsh.py16194%34
src/mrpro/algorithms/dcf
   dcf_voronoi.py53492%15, 48–49, 76
src/mrpro/algorithms/optimizers
   adam.py20195%69
src/mrpro/algorithms/reconstruction
   DirectReconstruction.py281643%51–71, 85
   IterativeSENSEReconstruction.py13192%76
   Reconstruction.py502256%42, 54–56, 80–87, 104–113
   RegularizedIterativeSENSEReconstruction.py411759%96–100, 114–139
src/mrpro/data
   AcqInfo.py128398%26, 169, 207
   CsmData.py29390%15, 82–84
   DcfData.py45882%18, 66, 78–83
   IData.py67987%119, 125, 129, 159–167
   IHeader.py75791%75, 109, 127–131
   KHeader.py1531789%25, 119–123, 150, 199, 210, 217–218, 221, 228, 260–271
   KNoise.py311552%39–52, 56–61
   KTrajectory.py69593%178–182
   MoveDataMixin.py1371887%15, 113, 129, 143–145, 207, 305–307, 320, 399, 419–420, 422, 437–438, 440
   QData.py39782%42, 65–73
   Rotation.py6743595%100, 198, 335, 433, 477, 495, 581, 583, 592, 626, 628, 691, 768, 773, 776, 791, 808, 813, 889, 1077, 1082, 1085, 1109, 1113, 1240, 1242, 1250–1251, 1315, 1397, 1690, 1846, 1881, 1885, 1996
   SpatialDimension.py2302191%33, 103, 128, 135, 141, 261–263, 276–278, 312, 330, 343, 356, 369, 382, 391–392, 407, 416
   acq_filters.py12192%47
src/mrpro/data/_kdata
   KData.py1121884%108–109, 124, 131, 141, 149, 203–204, 242, 247–248, 267–278
   KDataRemoveOsMixin.py29293%44, 46
   KDataSelectMixin.py19289%48, 63
   KDataSplitMixin.py48394%53, 84, 93
src/mrpro/data/traj_calculators
   KTrajectoryCalculator.py25292%23, 45
   KTrajectoryIsmrmrd.py13285%41, 50
   KTrajectoryPulseq.py29197%54
src/mrpro/operators
   CartesianSamplingOp.py50296%90, 116
   ConstraintsOp.py60297%46, 48
   EndomorphOperator.py65297%228, 234
   FiniteDifferenceOp.py27293%40, 105
   Functional.py71593%20–22, 117, 119
   GridSamplingOp.py136993%72–73, 82–83, 90–91, 94, 96, 98
   LinearOperator.py1711293%55, 91, 190, 220, 261, 270, 278, 287, 295, 320, 418, 423
   LinearOperatorMatrix.py1581690%82, 119, 152, 161, 166, 175–178, 191–194, 203, 215, 304, 331, 359
   MultiIdentityOp.py13285%43, 48
   Operator.py78297%25, 74
   ProximableFunctionalSeparableSum.py39392%50, 103, 110
   SliceProjectionOp.py173895%44, 61, 63, 69, 206, 227, 260, 300
   WaveletOp.py120596%152, 170, 205, 210, 233
   ZeroPadOp.py16194%30
src/mrpro/utils
   filters.py62297%44, 49
   slice_profiles.py46687%20, 36, 113–116, 149
   sliding_window.py34197%34
   split_idx.py10280%43, 47
   summarize_tensorvalues.py11918%20–29
   typing.py181139%8–23
   zero_pad_or_crop.py31681%26, 30, 54, 57, 60, 63
TOTAL473235093% 

Tests Skipped Failures Errors Time
1967 0 💤 0 ❌ 0 🔥 1m 42s ⏱️

Please sign in to comment.