diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index 502e9fce..a86bb4b3 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -1,8 +1,11 @@ """Fourier Operator.""" from collections.abc import Sequence +from itertools import product +import numpy as np import torch +from torchkbnufft import KbNufftAdjoint from typing_extensions import Self from mrpro.data._kdata.KData import KData @@ -127,7 +130,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil k-space data with shape: (... coils k2 k1 k0) """ - # FFT followed by NUFFT + # FFT followed by NUFFT Type 2 return self._non_uniform_fast_fourier_op(self._fast_fourier_op(x)[0]) def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: @@ -142,5 +145,160 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: ------- coil image data with shape: (... coils z y x) """ - # NUFFT followed by FFT + # NUFFT Type 1 followed by FFT return self._fast_fourier_op.adjoint(self._non_uniform_fast_fourier_op.adjoint(x)[0]) + + +def symmetrize(kernel: torch.Tensor, rank: int) -> torch.Tensor: + """Enforce hermitian symmetry on the kernel. Returns only half of the kernel.""" + flipped = kernel.clone() + for d in range(-rank, 0): + flipped = flipped.index_select(d, -1 * torch.arange(flipped.shape[d], device=flipped.device) % flipped.size(d)) + kernel = (kernel + flipped.conj()) / 2 + last_len = kernel.shape[-1] + return kernel[..., : last_len // 2 + 1] + + +def gram_nufft_kernel(weight: torch.Tensor, trajectory: torch.Tensor, recon_shape: Sequence[int]) -> torch.Tensor: + """Calculate the convolution kernel for the NUFFT gram operator. + + Parameters + ---------- + weight + either ones or density compensation weights + trajectory + k-space trajectory + recon_shape + shape of the reconstructed image + + Returns + ------- + kernel + real valued convolution kernel for the NUFFT gram operator, already in Fourier space + """ + rank = trajectory.shape[-2] + if rank != len(recon_shape): + raise ValueError('Rank of trajectory and image size must match.') + # Instead of doing one adjoint nufft with double the recon size in all dimensions, + # we do two adjoint nuffts per dimensions, saving a lot of memory. + adjnufft_ob = KbNufftAdjoint(im_size=recon_shape, n_shift=[0] * rank).to(trajectory) + + kernel = adjnufft_ob(weight, trajectory) # this will be the top left ... corner block + pad = [] + for s in kernel.shape[: -rank - 1 : -1]: + pad.extend([0, s]) + kernel = torch.nn.functional.pad(kernel, pad) # twice the size in all dimensions + + for flips in list(product([1, -1], repeat=rank)): + if all(flip == 1 for flip in flips): + # top left ... block already processed before padding + continue + flipped_trajectory = trajectory * torch.tensor(flips).to(trajectory).unsqueeze(-1) + kernel_part = adjnufft_ob(weight, flipped_trajectory) + slices = [] # which part of the kernel to is currently being processed + for dim, flip in zip(range(-rank, 0), flips, strict=True): + if flip > 0: # first half in the dimension + slices.append(slice(0, kernel_part.size(dim))) + else: # second half in the dimension + slices.append(slice(kernel_part.size(dim) + 1, None)) + kernel_part = kernel_part.index_select(dim, torch.arange(kernel_part.size(dim) - 1, 0, -1)) # flip + + kernel[[..., *slices]] = kernel_part + + kernel = symmetrize(kernel, rank) + kernel = torch.fft.hfftn(kernel, dim=list(range(-rank, 0)), norm='backward') + kernel /= kernel.shape[-rank:].numel() + kernel = torch.fft.fftshift(kernel, dim=list(range(-rank, 0))) + return kernel + + +class FourierGramOp(LinearOperator): + """Gram operator for the Fourier operator. + + Implements the adjoint of the forward operator of the Fourier operator, i.e. the gram operator + `F.H@F. + + Uses a convolution, implemented as multiplication in Fourier space, to calculate the gram operator + for the toeplitz NUFFT operator. + + Uses a multiplication with a binary mask in Fourier space to calculate the gram operator for + the Cartesian FFT operator + + This Operator is only used internally and should not be used directly. + Instead, consider using the `gram` property of :class: `mrpro.operators.FourierOp`. + """ + + _kernel: torch.Tensor | None + + def __init__(self, fourier_op: FourierOp) -> None: + """Initialize the gram operator. + + If density compensation weights are provided, they the operator + F.H@dcf@F is calculated. + + Parameters + ---------- + fourier_op + the Fourier operator to calculate the gram operator for + + """ + super().__init__() + if fourier_op._nufft_dims and fourier_op._omega is not None: + weight = torch.ones_like(fourier_op._omega[..., :1, :, :, :]) + keep_dims = [-4, *fourier_op._nufft_dims] # -4 is coil + permute = [i for i in range(-weight.ndim, 0) if i not in keep_dims] + keep_dims + unpermute = np.argsort(permute) + weight = weight.permute(*permute) + weight_unflattend_shape = weight.shape + weight = weight.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1) + weight = weight + 0j + omega = fourier_op._omega.permute(*permute) + omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1) + kernel = gram_nufft_kernel(weight, omega, fourier_op._nufft_im_size) + kernel = kernel.reshape(*weight_unflattend_shape[: -len(keep_dims)], *kernel.shape[-len(keep_dims) :]) + kernel = kernel.permute(*unpermute) + fft = FastFourierOp( + dim=fourier_op._nufft_dims, + encoding_matrix=[2 * s for s in fourier_op._nufft_im_size], + recon_matrix=fourier_op._nufft_im_size, + ) + self.nufft_gram: None | LinearOperator = fft.H * kernel @ fft + else: + self.nufft_gram = None + + if fourier_op._fast_fourier_op is not None and fourier_op._cart_sampling_op is not None: + self.fast_fourier_gram: None | LinearOperator = ( + fourier_op._fast_fourier_op.H @ fourier_op._cart_sampling_op.gram @ fourier_op._fast_fourier_op + ) + else: + self.fast_fourier_gram = None + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the operator to the input tensor. + + Parameters + ---------- + x + input tensor, shape (..., coils, z, y, x) + """ + if self.nufft_gram is not None: + (x,) = self.nufft_gram(x) + + if self.fast_fourier_gram is not None: + (x,) = self.fast_fourier_gram(x) + return (x,) + + def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint operator to the input tensor. + + Parameters + ---------- + x + input tensor, shape (..., coils, k2, k1, k0) + """ + return self.forward(x) + + @property + def H(self) -> Self: # noqa: N802 + """Adjoint operator of the gram operator.""" + return self diff --git a/src/mrpro/operators/NonUniformFastFourierOp.py b/src/mrpro/operators/NonUniformFastFourierOp.py index a3df4241..ac92437c 100644 --- a/src/mrpro/operators/NonUniformFastFourierOp.py +++ b/src/mrpro/operators/NonUniformFastFourierOp.py @@ -63,17 +63,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): # Broadcast shapes not always needed but also does not hurt omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega] self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction - + numpoints = [min(img_size, nufft_numpoints) for img_size in recon_matrix] self._fwd_nufft_op = KbNufft( im_size=recon_matrix, grid_size=grid_size, - numpoints=nufft_numpoints, + numpoints=numpoints, kbwidth=nufft_kbwidth, ) self._adj_nufft_op = KbNufftAdjoint( im_size=recon_matrix, grid_size=grid_size, - numpoints=nufft_numpoints, + numpoints=numpoints, kbwidth=nufft_kbwidth, ) diff --git a/tests/operators/test_non_uniform_fast_fourier_op.py b/tests/operators/test_non_uniform_fast_fourier_op.py index e3c5b4d1..49c9cb7e 100644 --- a/tests/operators/test_non_uniform_fast_fourier_op.py +++ b/tests/operators/test_non_uniform_fast_fourier_op.py @@ -11,13 +11,13 @@ from tests.helper import dotproduct_adjointness_test, relative_image_difference -def create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz): +def create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): random_generator = RandomGenerator(seed=0) # generate random image img = random_generator.complex64_tensor(size=im_shape) # create random trajectories - trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz) + trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) return img, trajectory @@ -31,7 +31,9 @@ def test_non_uniform_fast_fourier_op_fwd_adj_property(dim): # generate random traj nk = [kdata_shape[d] if d in dim else 1 for d in (-5, -3, -2, -1)] # skip coil dimension - traj = create_traj(kdata_shape, nkx=nk, nky=nk, nkz=nk, sx='nuf', sy='nuf', sz='nuf') + traj = create_traj( + kdata_shape, nkx=nk, nky=nk, nkz=nk, type_kx='non-uniform', type_ky='non-uniform', type_kz='non-uniform' + ) # create operator nufft_op = NonUniformFastFourierOp( @@ -73,7 +75,7 @@ def test_non_uniform_fast_fourier_op_equal_to_fft(ismrmrd_cart): def test_non_uniform_fast_fourier_op_empty_dims(): """Empty dims do not change the input.""" nk = [1, 1, 1, 1, 1] - traj = create_traj(nk, nkx=nk, nky=nk, nkz=nk, sx='nuf', sy='nuf', sz='nuf') + traj = create_traj(nk, nkx=nk, nky=nk, nkz=nk, type_kx='non-uniform', type_ky='non-uniform', type_kz='non-uniform') nufft_op = NonUniformFastFourierOp(dim=(), recon_matrix=(), encoding_matrix=(), traj=traj) random_generator = RandomGenerator(seed=0)