diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index f4254d8d2..c57d1fbfd 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -11,6 +11,7 @@ from mrpro.data.enums import TrajType from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.operators.CartesianSamplingOp import CartesianSamplingOp from mrpro.operators.FastFourierOp import FastFourierOp from mrpro.operators.LinearOperator import LinearOperator @@ -67,12 +68,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): self._nufft_dims.append(dim) if self._fft_dims: - self._fast_fourier_op = FastFourierOp( + self._fast_fourier_op: FastFourierOp | None = FastFourierOp( dim=tuple(self._fft_dims), recon_matrix=get_spatial_dims(recon_matrix, self._fft_dims), encoding_matrix=get_spatial_dims(encoding_matrix, self._fft_dims), ) - + self._cart_sampling_op: CartesianSamplingOp | None = CartesianSamplingOp( + encoding_matrix=encoding_matrix, traj=traj + ) + else: + self._fast_fourier_op = None + self._cart_sampling_op = None # Find dimensions which require NUFFT if self._nufft_dims: fft_dims_k210 = [ @@ -102,20 +108,23 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega] self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction - self._fwd_nufft_op = KbNufft( + self._fwd_nufft_op: KbNufftAdjoint | None = KbNufft( im_size=self._nufft_im_size, grid_size=grid_size, numpoints=nufft_numpoints, kbwidth=nufft_kbwidth, ) - self._adj_nufft_op = KbNufftAdjoint( + self._adj_nufft_op: KbNufftAdjoint | None = KbNufftAdjoint( im_size=self._nufft_im_size, grid_size=grid_size, numpoints=nufft_numpoints, kbwidth=nufft_kbwidth, ) - - self._kshape = traj.broadcasted_shape + else: + self._omega: torch.Tensor | None = None + self._fwd_nufft_op = None + self._adj_nufft_op = None + self._kshape = traj.broadcasted_shape @classmethod def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> Self: @@ -146,11 +155,8 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil k-space data with shape: (... coils k2 k1 k0) """ - if len(self._fft_dims): - # FFT - (x,) = self._fast_fourier_op(x) - - if self._nufft_dims: + if self._fwd_nufft_op is not None and self._omega is not None: + # NUFFT Type 2 # we need to move the nufft-dimensions to the end and flatten all other dimensions # so the new shape will be (... non_nufft_dims) coils nufft_dims # we could move the permute to __init__ but then we still would need to prepend if len(other)>1 @@ -163,7 +169,6 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: x = x.flatten(end_dim=-len(keep_dims) - 1) # omega should be (... non_nufft_dims) n_nufft_dims (nufft_dims) - # TODO: consider moving the broadcast along fft dimensions to __init__ (independent of x shape). omega = self._omega.permute(*permute) omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :]) omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1) @@ -173,6 +178,11 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: shape_nufft_dims = [self._kshape[i] for i in self._nufft_dims] x = x.reshape(*permuted_x_shape[: -len(keep_dims)], -1, *shape_nufft_dims) # -1 is coils x = x.permute(*unpermute) + + if self._fast_fourier_op is not None and self._cart_sampling_op is not None: + # FFT + (x,) = self._cart_sampling_op(self._fast_fourier_op(x)[0]) + return (x,) def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: @@ -187,11 +197,12 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil image data with shape: (... coils z y x) """ - if self._fft_dims: + if self._fast_fourier_op is not None and self._cart_sampling_op is not None: # IFFT - (x,) = self._fast_fourier_op.adjoint(x) + (x,) = self._fast_fourier_op.adjoint(self._cart_sampling_op.adjoint(x)[0]) - if self._nufft_dims: + if self._adj_nufft_op is not None and self._omega is not None: + # NUFFT Type 1 # we need to move the nufft-dimensions to the end, flatten them and flatten all other dimensions # so the new shape will be (... non_nufft_dims) coils (nufft_dims) keep_dims = [-4, *self._nufft_dims] # -4 is coil diff --git a/tests/conftest.py b/tests/conftest.py index 1fb7fb95f..3bd8946f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from xsdata.models.datatype import XmlDate, XmlTime from tests import RandomGenerator +from tests.data import IsmrmrdRawTestData from tests.phantoms import EllipsePhantomTestData @@ -249,6 +250,19 @@ def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): return trajectory +@pytest.fixture(scope='session') +def ismrmrd_cart(ellipse_phantom, tmp_path_factory): + """Fully sampled cartesian data set.""" + ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' + ismrmrd_kdata = IsmrmrdRawTestData( + filename=ismrmrd_filename, + noise_level=0.0, + repetitions=3, + phantom=ellipse_phantom.phantom, + ) + return ismrmrd_kdata + + COMMON_MR_TRAJECTORIES = pytest.mark.parametrize( ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz', 'type_k0', 'type_k1', 'type_k2'), [ diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index ab4f5aabb..d5cfa0f0c 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -16,19 +16,6 @@ from tests.phantoms import EllipsePhantomTestData -@pytest.fixture(scope='session') -def ismrmrd_cart(ellipse_phantom, tmp_path_factory): - """Fully sampled cartesian data set.""" - ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5' - ismrmrd_kdata = IsmrmrdRawTestData( - filename=ismrmrd_filename, - noise_level=0.0, - repetitions=3, - phantom=ellipse_phantom.phantom, - ) - return ismrmrd_kdata - - @pytest.fixture(scope='session') def ismrmrd_cart_bodycoil_and_surface_coil(ellipse_phantom, tmp_path_factory): """Fully sampled cartesian data set with bodycoil and surface coil data.""" diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index 6f8c377cf..f48a24260 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -1,7 +1,9 @@ """Tests for Fourier operator.""" import pytest -from mrpro.data import SpatialDimension +import torch +from mrpro.data import KData, KTrajectory, SpatialDimension +from mrpro.data.traj_calculators import KTrajectoryCartesian from mrpro.operators import FourierOp from tests import RandomGenerator @@ -30,7 +32,11 @@ def test_fourier_op_fwd_adj_property( # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) - encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) # apply forward operator @@ -70,3 +76,22 @@ def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) with pytest.raises(NotImplementedError, match='Cartesian FFT dims need to be aligned'): FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) + + +def test_fourier_op_cartesian_sorting(ismrmrd_cart): + """Verify correct sorting of Cartesian k-space data before FFT.""" + kdata = KData.from_file(ismrmrd_cart.filename, KTrajectoryCartesian()) + ff_op = FourierOp.from_kdata(kdata) + (img,) = ff_op.adjoint(kdata.data) + + # shuffle the kspace points along k0 + permutation_index = torch.randperm(kdata.data.shape[-1]) + kdata_unsorted = KData( + header=kdata.header, + data=kdata.data[..., permutation_index], + traj=KTrajectory.from_tensor(kdata.traj.as_tensor()[..., permutation_index]), + ) + ff_op_unsorted = FourierOp.from_kdata(kdata_unsorted) + (img_unsorted,) = ff_op_unsorted.adjoint(kdata_unsorted.data) + + torch.testing.assert_close(img, img_unsorted)