fix merge problems
ckolbPTB committed Dec 5, 2024
1 parent a4b8a64 commit 2a90533
Showing 3 changed files with 169 additions and 9 deletions.
162 changes: 160 additions & 2 deletions src/mrpro/operators/FourierOp.py
@@ -1,8 +1,11 @@
"""Fourier Operator."""

from collections.abc import Sequence
from itertools import product

import numpy as np
import torch
from torchkbnufft import KbNufftAdjoint
from typing_extensions import Self

from mrpro.data._kdata.KData import KData
@@ -127,7 +130,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
-------
coil k-space data with shape: (... coils k2 k1 k0)
"""
# FFT followed by NUFFT
# FFT followed by NUFFT Type 2
return self._non_uniform_fast_fourier_op(self._fast_fourier_op(x)[0])

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
@@ -142,5 +145,160 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
-------
coil image data with shape: (... coils z y x)
"""
# NUFFT followed by FFT
# NUFFT Type 1 followed by FFT
return self._fast_fourier_op.adjoint(self._non_uniform_fast_fourier_op.adjoint(x)[0])
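# Editor's note on terminology, not part of the commit: a type-2 NUFFT evaluates
# samples at non-uniform k-space locations from data on a uniform grid (the forward
# direction above), while a type-1 NUFFT grids non-uniform samples back onto a
# uniform grid (the adjoint). The Toeplitz machinery below fuses the two into a
# single convolution.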


def symmetrize(kernel: torch.Tensor, rank: int) -> torch.Tensor:
"""Enforce Hermitian symmetry on the kernel. Returns only half of the kernel."""
flipped = kernel.clone()
for d in range(-rank, 0):
flipped = flipped.index_select(d, -1 * torch.arange(flipped.shape[d], device=flipped.device) % flipped.size(d))
kernel = (kernel + flipped.conj()) / 2
last_len = kernel.shape[-1]
return kernel[..., : last_len // 2 + 1]
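# Editor's sketch, not from this commit: for rank=1 and a length-n kernel, entry i
# is averaged with the conjugate of entry (-i) % n, so the result satisfies
# k[i] == k[-i].conj(); only the first n // 2 + 1 entries are kept, the rest being
# redundant for a Hermitian kernel. For example:
#   k = torch.tensor([1 + 1j, 2 + 2j, 3 + 0j, 2 - 2j])
#   symmetrize(k, rank=1)  # tensor([1.+0.j, 2.+2.j, 3.+0.j])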


def gram_nufft_kernel(weight: torch.Tensor, trajectory: torch.Tensor, recon_shape: Sequence[int]) -> torch.Tensor:
"""Calculate the convolution kernel for the NUFFT gram operator.

Parameters
----------
weight
either ones or density compensation weights
trajectory
k-space trajectory
recon_shape
shape of the reconstructed image

Returns
-------
kernel
real-valued convolution kernel for the NUFFT gram operator, already in Fourier space
"""
rank = trajectory.shape[-2]
if rank != len(recon_shape):
raise ValueError('Rank of trajectory and image size must match.')
# Instead of doing one adjoint nufft with double the recon size in all dimensions,
# we do two adjoint nuffts per dimension (2**rank in total), saving a lot of memory.
adjnufft_ob = KbNufftAdjoint(im_size=recon_shape, n_shift=[0] * rank).to(trajectory)

kernel = adjnufft_ob(weight, trajectory) # this will be the top left ... corner block
pad = []
for s in kernel.shape[: -rank - 1 : -1]:
pad.extend([0, s])
kernel = torch.nn.functional.pad(kernel, pad) # twice the size in all dimensions

for flips in list(product([1, -1], repeat=rank)):
if all(flip == 1 for flip in flips):
# top left ... block already processed before padding
continue
flipped_trajectory = trajectory * torch.tensor(flips).to(trajectory).unsqueeze(-1)
kernel_part = adjnufft_ob(weight, flipped_trajectory)
slices = []  # which part of the kernel is currently being processed
for dim, flip in zip(range(-rank, 0), flips, strict=True):
if flip > 0: # first half in the dimension
slices.append(slice(0, kernel_part.size(dim)))
else: # second half in the dimension
slices.append(slice(kernel_part.size(dim) + 1, None))
kernel_part = kernel_part.index_select(dim, torch.arange(kernel_part.size(dim) - 1, 0, -1)) # flip

kernel[(..., *slices)] = kernel_part

kernel = symmetrize(kernel, rank)
kernel = torch.fft.hfftn(kernel, dim=list(range(-rank, 0)), norm='backward')
kernel /= kernel.shape[-rank:].numel()
kernel = torch.fft.fftshift(kernel, dim=list(range(-rank, 0)))
return kernel
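# Editor's sketch, not from this commit: how such a kernel is applied. The pad and
# crop helpers below are hypothetical stand-ins for the oversampled FastFourierOp
# that FourierGramOp builds further down (encoding_matrix twice the recon size);
# fftshift conventions are omitted for brevity.
#   dims = tuple(range(-rank, 0))
#   x_pad = zero_pad(x, [2 * s for s in recon_shape])   # hypothetical helper
#   y = torch.fft.ifftn(kernel * torch.fft.fftn(x_pad, dim=dims), dim=dims)
#   y = center_crop(y, recon_shape)                     # hypothetical helper
# This evaluates F.H @ F @ x without revisiting the non-uniform trajectory.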


class FourierGramOp(LinearOperator):
"""Gram operator for the Fourier operator.

Implements the composition of the adjoint and the forward Fourier operator,
i.e. the gram operator `F.H @ F`.

Uses a convolution, implemented as a multiplication in Fourier space, to calculate
the gram operator for the Toeplitz NUFFT operator.

Uses a multiplication with a binary mask in Fourier space to calculate the gram
operator for the Cartesian FFT operator.

This operator is only used internally and should not be used directly.
Instead, consider using the `gram` property of :class:`mrpro.operators.FourierOp`.
"""

_kernel: torch.Tensor | None

def __init__(self, fourier_op: FourierOp) -> None:
"""Initialize the gram operator.

If density compensation weights are provided, the operator `F.H @ dcf @ F`
is calculated instead.

Parameters
----------
fourier_op
the Fourier operator to calculate the gram operator for
"""
super().__init__()
if fourier_op._nufft_dims and fourier_op._omega is not None:
weight = torch.ones_like(fourier_op._omega[..., :1, :, :, :])
keep_dims = [-4, *fourier_op._nufft_dims] # -4 is coil
permute = [i for i in range(-weight.ndim, 0) if i not in keep_dims] + keep_dims
unpermute = np.argsort(permute)
weight = weight.permute(*permute)
weight_unflattened_shape = weight.shape
weight = weight.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)
weight = weight + 0j
omega = fourier_op._omega.permute(*permute)
omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)
kernel = gram_nufft_kernel(weight, omega, fourier_op._nufft_im_size)
kernel = kernel.reshape(*weight_unflattened_shape[: -len(keep_dims)], *kernel.shape[-len(keep_dims) :])
kernel = kernel.permute(*unpermute)
fft = FastFourierOp(
dim=fourier_op._nufft_dims,
encoding_matrix=[2 * s for s in fourier_op._nufft_im_size],
recon_matrix=fourier_op._nufft_im_size,
)
self.nufft_gram: None | LinearOperator = fft.H * kernel @ fft
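# Editor's note, an assumption about the line above rather than part of the commit:
# `fft.H * kernel @ fft` chains an oversampled FFT, an elementwise multiplication
# with the precomputed kernel (assuming mrpro's LinearOperator `*` composes with a
# tensor that way), and the adjoint FFT. The mismatch between encoding_matrix
# (twice the image size) and recon_matrix makes the FFT pair perform the
# zero-padding and cropping required by the Toeplitz embedding.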
else:
self.nufft_gram = None

if fourier_op._fast_fourier_op is not None and fourier_op._cart_sampling_op is not None:
self.fast_fourier_gram: None | LinearOperator = (
fourier_op._fast_fourier_op.H @ fourier_op._cart_sampling_op.gram @ fourier_op._fast_fourier_op
)
else:
self.fast_fourier_gram = None

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply the operator to the input tensor.

Parameters
----------
x
input tensor, shape (..., coils, z, y, x)
"""
if self.nufft_gram is not None:
(x,) = self.nufft_gram(x)

if self.fast_fourier_gram is not None:
(x,) = self.fast_fourier_gram(x)
return (x,)
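# Editor's note, not part of the commit: the two gram blocks act on disjoint
# dimensions of the FourierOp (non-Cartesian NUFFT dims and Cartesian FFT dims),
# so applying them sequentially realizes the gram of the full operator.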

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply the adjoint operator to the input tensor.

Parameters
----------
x
input tensor, shape (..., coils, k2, k1, k0)
"""
return self.forward(x)

@property
def H(self) -> Self: # noqa: N802
"""Adjoint operator of the gram operator."""
return self
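A usage sketch for context (editor's addition, not part of the commit): client code
reaches this class through the `gram` property mentioned in the docstring rather than
constructing it directly. `FourierOp.from_kdata` is assumed as the usual entry point;
`kdata` and `x` are placeholder inputs.

fourier_op = FourierOp.from_kdata(kdata)
gram = fourier_op.gram  # a FourierGramOp under the hood
(y,) = gram(x)  # equivalent to fourier_op.adjoint(*fourier_op(x))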
6 changes: 3 additions & 3 deletions src/mrpro/operators/NonUniformFastFourierOp.py
@@ -63,17 +63,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
# Broadcast shapes not always needed but also does not hurt
omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega]
self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction

numpoints = [min(img_size, nufft_numpoints) for img_size in recon_matrix]
self._fwd_nufft_op = KbNufft(
im_size=recon_matrix,
grid_size=grid_size,
numpoints=nufft_numpoints,
numpoints=numpoints,
kbwidth=nufft_kbwidth,
)
self._adj_nufft_op = KbNufftAdjoint(
im_size=recon_matrix,
grid_size=grid_size,
numpoints=nufft_numpoints,
numpoints=numpoints,
kbwidth=nufft_kbwidth,
)
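# Editor's note, not part of the commit: the fix above clamps the interpolation
# kernel width to the image size in each dimension. For example, with
# recon_matrix=(4, 64, 64) and nufft_numpoints=6, numpoints becomes [4, 6, 6],
# so a 4-voxel dimension no longer requests a 6-point Kaiser-Bessel kernel.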

10 changes: 6 additions & 4 deletions tests/operators/test_non_uniform_fast_fourier_op.py
@@ -11,13 +11,13 @@
from tests.helper import dotproduct_adjointness_test, relative_image_difference


def create_data(im_shape, k_shape, nkx, nky, nkz, sx, sy, sz):
def create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz):
random_generator = RandomGenerator(seed=0)

# generate random image
img = random_generator.complex64_tensor(size=im_shape)
# create random trajectories
trajectory = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz)
trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz)
return img, trajectory


@@ -31,7 +31,9 @@ def test_non_uniform_fast_fourier_op_fwd_adj_property(dim):

# generate random traj
nk = [kdata_shape[d] if d in dim else 1 for d in (-5, -3, -2, -1)] # skip coil dimension
traj = create_traj(kdata_shape, nkx=nk, nky=nk, nkz=nk, sx='nuf', sy='nuf', sz='nuf')
traj = create_traj(
kdata_shape, nkx=nk, nky=nk, nkz=nk, type_kx='non-uniform', type_ky='non-uniform', type_kz='non-uniform'
)

# create operator
nufft_op = NonUniformFastFourierOp(
@@ -73,7 +75,7 @@ def test_non_uniform_fast_fourier_op_equal_to_fft(ismrmrd_cart):
def test_non_uniform_fast_fourier_op_empty_dims():
"""Empty dims do not change the input."""
nk = [1, 1, 1, 1, 1]
traj = create_traj(nk, nkx=nk, nky=nk, nkz=nk, sx='nuf', sy='nuf', sz='nuf')
traj = create_traj(nk, nkx=nk, nky=nk, nkz=nk, type_kx='non-uniform', type_ky='non-uniform', type_kz='non-uniform')
nufft_op = NonUniformFastFourierOp(dim=(), recon_matrix=(), encoding_matrix=(), traj=traj)

random_generator = RandomGenerator(seed=0)
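A note on the remaining change in this file (editor's addition): the abbreviated
sx/sy/sz arguments of create_traj were renamed to type_kx/type_ky/type_kz and their
values spelled out, 'nuf' becoming 'non-uniform', which makes the parametrization
self-describing. A hypothetical mixed call, assuming 'uniform' and 'zero' are the
other accepted values:

traj = create_traj(k_shape, nkx=nk, nky=nk, nkz=nk, type_kx='non-uniform', type_ky='uniform', type_kz='zero')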