Commit

Merge branch 'main' into reshape_view

fzimmermann89 authored Dec 3, 2024
2 parents f223b3a + 82801e6 commit 43b2920
Showing 6 changed files with 27 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/mrpro/VERSION
@@ -1 +1 @@
-0.241112
+0.241126
2 changes: 1 addition & 1 deletion src/mrpro/data/DcfData.py
@@ -56,7 +56,7 @@ def from_traj_voronoi(cls, traj: KTrajectory) -> Self:
 
         if ks_needing_voronoi:
             # Handle full dimensions needing voronoi
-            dcfs.append(smap(dcf_2d3d_voronoi, torch.stack(list(ks_needing_voronoi), -4), 4))
+            dcfs.append(smap(dcf_2d3d_voronoi, torch.stack(torch.broadcast_tensors(*ks_needing_voronoi), -4), 4))
 
         if dcfs:
             # Multiply all dcfs together
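
The fix above is needed because torch.stack requires all inputs to have identical shapes, while the k-space components collected in ks_needing_voronoi may only be broadcast-compatible (singleton dimensions). A minimal sketch of the failure mode and the fix, with illustrative shapes:

    import torch

    kx = torch.randn(1, 1, 4, 4)
    ky = torch.randn(1, 4, 1, 4)

    # torch.stack([kx, ky], -4) would raise a RuntimeError ("stack expects
    # each tensor to be equal size"). Broadcasting first expands both
    # tensors to the common shape (1, 4, 4, 4) before stacking.
    stacked = torch.stack(torch.broadcast_tensors(kx, ky), -4)
    print(stacked.shape)  # torch.Size([1, 2, 4, 4, 4])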
6 changes: 3 additions & 3 deletions src/mrpro/operators/FourierOp.py
@@ -108,17 +108,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
             # Broadcast shapes not always needed but also does not hurt
             omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega]
             self.register_buffer('_omega', torch.stack(omega, dim=-4))  # use the 'coil' dim for the direction
-
+            numpoints = [min(img_size, nufft_numpoints) for img_size in self._nufft_im_size]
             self._fwd_nufft_op: KbNufftAdjoint | None = KbNufft(
                 im_size=self._nufft_im_size,
                 grid_size=grid_size,
-                numpoints=nufft_numpoints,
+                numpoints=numpoints,
                 kbwidth=nufft_kbwidth,
             )
             self._adj_nufft_op: KbNufftAdjoint | None = KbNufftAdjoint(
                 im_size=self._nufft_im_size,
                 grid_size=grid_size,
-                numpoints=nufft_numpoints,
+                numpoints=numpoints,
                 kbwidth=nufft_kbwidth,
             )
         else:
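
The new numpoints list caps the interpolation kernel width per dimension at the image size, avoiding a kernel wider than the grid along very small axes. A minimal sketch with made-up sizes (the real values come from FourierOp's nufft_numpoints parameter and _nufft_im_size attribute):

    # Illustrative values only.
    nufft_numpoints = 6
    nufft_im_size = (4, 256)  # e.g. a thin slab axis and a full in-plane axis

    numpoints = [min(img_size, nufft_numpoints) for img_size in nufft_im_size]
    print(numpoints)  # [4, 6] -- the kernel is capped along the 4-point axis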
7 changes: 1 addition & 6 deletions src/mrpro/operators/LinearOperator.py
@@ -415,9 +415,4 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
     @property
     def H(self) -> LinearOperator:  # noqa: N802
         """Adjoint of adjoint operator, i.e. original LinearOperator."""
-        return self.operator
-
-    @property
-    def gram(self) -> LinearOperator:
-        """Gram operator."""
-        return self._operator.gram.H
+        return self._operator
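
This change fixes the attribute name (the wrapper stores the original operator in _operator, not operator), so taking the adjoint of an adjoint now returns the original object, and it removes the adjoint's custom gram property. A simplified sketch (not mrpro's actual classes) of the identity the fix restores:

    class Op:
        """Stand-in for a linear operator (hypothetical, for illustration)."""

        @property
        def H(self):
            # The adjoint is represented by a wrapper around self.
            return AdjointOp(self)


    class AdjointOp(Op):
        """Wrapper representing the adjoint of an operator."""

        def __init__(self, operator):
            self._operator = operator

        @property
        def H(self):
            # Return the stored original instance itself, not a new wrapper,
            # so op.H.H is op holds by object identity.
            return self._operator


    op = Op()
    assert op.H.H is op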
16 changes: 15 additions & 1 deletion tests/data/test_dcf_data.py
@@ -5,6 +5,8 @@
 from einops import repeat
 from mrpro.data import DcfData, KTrajectory
 
+from tests import RandomGenerator
+
 
 def example_traj_rpe(n_kr, n_ka, n_k0, broadcast=True):
     """Create RPE trajectory with uniform angular gap."""
@@ -17,7 +19,7 @@ def example_traj_rpe(n_kr, n_ka, n_k0, broadcast=True):
     return trajectory
 
 
-def example_traj_spiral_2d(n_kr, n_ki, n_ka, broadcast=True) -> KTrajectory:
+def example_traj_spiral_2d(n_kr: int, n_ki: int, n_ka: int, broadcast: bool = True) -> KTrajectory:
     """Create 2D spiral trajectory with n_kr points along each spiral arm, n_ki
     turns per spiral arm and n_ka spiral arms."""
     ang = repeat(torch.linspace(0, 2 * torch.pi * n_ki, n_kr), 'k0 -> other k2 k1 k0', other=1, k2=1, k1=1)
@@ -82,3 +84,15 @@ def test_dcf_rpe_traj_voronoi_cuda(n_kr, n_ka, n_k0):
     trajectory = example_traj_rpe(n_kr, n_ka, n_k0)
     dcf = DcfData.from_traj_voronoi(trajectory.cuda())
     assert dcf.data.is_cuda
+
+
+def test_dcf_broadcast():
+    """Test broadcasting within voronoi dcf calculation."""
+    rng = RandomGenerator(0)
+    # kx and ky force voronoi calculation and need to be broadcasted
+    kx = rng.float32_tensor((1, 1, 4, 4))
+    ky = rng.float32_tensor((1, 4, 1, 4))
+    kz = torch.zeros(1, 1, 1, 1)
+    trajectory = KTrajectory(kz, ky, kx)
+    dcf = DcfData.from_traj_voronoi(trajectory)
+    assert dcf.data.shape == trajectory.broadcasted_shape
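
The expected shape in the final assertion is the common broadcast shape of the three components; assuming trajectory.broadcasted_shape is exactly that, the expectation can be checked directly:

    import torch

    # Shapes mirror kz, ky, kx in the test above.
    shape = torch.broadcast_shapes((1, 1, 1, 1), (1, 4, 1, 4), (1, 1, 4, 4))
    print(shape)  # torch.Size([1, 4, 4, 4])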
6 changes: 6 additions & 0 deletions tests/operators/test_operators.py
@@ -337,6 +337,12 @@ def test_sum_operator_multiple_adjoint():
     dotproduct_adjointness_test(linear_op_sum, u, v)
 
 
+def test_adjoint_of_adjoint():
+    """Test that the adjoint of the adjoint is the original operator"""
+    a = DummyLinearOperator(RandomGenerator(7).complex64_tensor((3, 10)))
+    assert a.H.H is a
+
+
 def test_gram_shortcuts():
     """Test that .gram for composition and scalar multiplication results in shortcuts."""
 