Merge branch 'main' into gradient_checks_only_for_lin_op
ckolbPTB committed Nov 27, 2024 · 2 parents 0c615dc + 82801e6 · commit 4eccec5
Showing 10 changed files with 109 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/mrpro/VERSION
@@ -1 +1 @@
0.241112
0.241126
2 changes: 1 addition & 1 deletion src/mrpro/data/DcfData.py
@@ -56,7 +56,7 @@ def from_traj_voronoi(cls, traj: KTrajectory) -> Self:

if ks_needing_voronoi:
# Handle full dimensions needing voronoi
dcfs.append(smap(dcf_2d3d_voronoi, torch.stack(list(ks_needing_voronoi), -4), 4))
dcfs.append(smap(dcf_2d3d_voronoi, torch.stack(torch.broadcast_tensors(*ks_needing_voronoi), -4), 4))

if dcfs:
# Multiply all dcfs together
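The change above replaces a plain torch.stack of the trajectory components with a stack of broadcast tensors: torch.stack requires all tensors to have identical shapes, whereas trajectory components may only be broadcast-compatible. A minimal standalone sketch of the difference (not the mrpro implementation; the shapes are made up):

import torch

kx = torch.zeros(1, 1, 4, 4)  # components that only share a common broadcast shape
ky = torch.zeros(1, 4, 1, 4)

# torch.stack((kx, ky), dim=-4) would raise a RuntimeError because the shapes differ.
stacked = torch.stack(torch.broadcast_tensors(kx, ky), dim=-4)
print(stacked.shape)  # torch.Size([1, 2, 4, 4, 4])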
18 changes: 18 additions & 0 deletions src/mrpro/data/MoveDataMixin.py
@@ -239,6 +239,24 @@ def _convert(data: T) -> T:
new.apply_(_convert, memo=memo, recurse=False)
return new

def apply(
self: Self,
function: Callable[[Any], Any] | None = None,
*,
recurse: bool = True,
) -> Self:
"""Apply a function to all children. Returns a new object.
Parameters
----------
function
The function to apply to all fields. None is interpreted as a no-op.
recurse
If True, the function will be applied to all children that are MoveDataMixin instances.
"""
new = self.clone().apply_(function, recurse=recurse)
return new

def apply_(
self: Self,
function: Callable[[Any], Any] | None = None,
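The new apply method is the out-of-place counterpart of apply_: it clones the object and applies the function to the clone, leaving the original untouched. A hedged usage sketch with SpatialDimension, a MoveDataMixin subclass from this repository (assumes an installed mrpro; the multiply-by-two function mirrors the one used in the tests):

import torch
from mrpro.data import SpatialDimension

def multiply_by_2(obj):
    # scale tensors, leave all other fields unchanged
    return obj * 2 if isinstance(obj, torch.Tensor) else obj

sd = SpatialDimension(z=torch.ones(3), y=torch.ones(3), x=torch.ones(3))
doubled = sd.apply(multiply_by_2)  # returns a modified clone, sd is unchanged
sd.apply_(multiply_by_2)           # the in-place variant modifies sd itself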
13 changes: 13 additions & 0 deletions src/mrpro/data/SpatialDimension.py
@@ -18,6 +18,7 @@
VectorTypes = torch.Tensor
ScalarTypes = int | float
T = TypeVar('T', torch.Tensor, int, float)

# Covariant types, as SpatialDimension is a Container
# and we want, for example, SpatialDimension[int] to also be a SpatialDimension[float]
T_co = TypeVar('T_co', torch.Tensor, int, float, covariant=True)
@@ -108,6 +109,7 @@ def from_array_zyx(

return SpatialDimension(z, y, x)

# This function is mainly for type hinting and docstring
def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self:
"""Apply a function to each z, y, x (in-place).
@@ -118,6 +120,17 @@ def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self:
"""
return super(SpatialDimension, self).apply_(function)

# This function is mainly for type hinting and docstring
def apply(self, function: Callable[[T], T] | None = None, **_) -> Self:
"""Apply a function to each z, y, x (returning a new object).
Parameters
----------
function
function to apply
"""
return super(SpatialDimension, self).apply(function)

@property
def zyx(self) -> tuple[T_co, T_co, T_co]:
"""Return a z,y,x tuple."""
6 changes: 3 additions & 3 deletions src/mrpro/operators/FourierOp.py
@@ -108,17 +108,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
# Broadcast shapes not always needed but also does not hurt
omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega]
self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction

numpoints = [min(img_size, nufft_numpoints) for img_size in self._nufft_im_size]
self._fwd_nufft_op: KbNufftAdjoint | None = KbNufft(
im_size=self._nufft_im_size,
grid_size=grid_size,
numpoints=nufft_numpoints,
numpoints=numpoints,
kbwidth=nufft_kbwidth,
)
self._adj_nufft_op: KbNufftAdjoint | None = KbNufftAdjoint(
im_size=self._nufft_im_size,
grid_size=grid_size,
numpoints=nufft_numpoints,
numpoints=numpoints,
kbwidth=nufft_kbwidth,
)
else:
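The FourierOp change caps the NUFFT interpolation kernel width per dimension at the image size, so a small image dimension is never assigned a kernel wider than its grid. A minimal sketch of just the clamping (plain Python; the sizes are made up):

nufft_numpoints = 6          # assumed default kernel width
nufft_im_size = (4, 64, 64)  # hypothetical image size with one small dimension
numpoints = [min(img_size, nufft_numpoints) for img_size in nufft_im_size]
print(numpoints)  # [4, 6, 6] -- the first dimension is clamped to its grid size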
7 changes: 1 addition & 6 deletions src/mrpro/operators/LinearOperator.py
@@ -415,9 +415,4 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
@property
def H(self) -> LinearOperator: # noqa: N802
"""Adjoint of adjoint operator, i.e. original LinearOperator."""
return self.operator

@property
def gram(self) -> LinearOperator:
"""Gram operator."""
return self._operator.gram.H
return self._operator
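With this change the H property of the adjoint wrapper returns the stored original operator, and the wrapper's separate gram property is dropped, so taking the adjoint twice yields the identical object. A hedged sketch of the resulting identity (FastFourierOp is used only as an example operator and is not part of this diff):

from mrpro.operators import FastFourierOp

op = FastFourierOp()  # any LinearOperator would do
assert op.H.H is op   # the adjoint of the adjoint is the original operator object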
16 changes: 15 additions & 1 deletion tests/data/test_dcf_data.py
@@ -5,6 +5,8 @@
from einops import repeat
from mrpro.data import DcfData, KTrajectory

from tests import RandomGenerator


def example_traj_rpe(n_kr, n_ka, n_k0, broadcast=True):
"""Create RPE trajectory with uniform angular gap."""
@@ -17,7 +19,7 @@ def example_traj_rpe(n_kr, n_ka, n_k0, broadcast=True):
return trajectory


def example_traj_spiral_2d(n_kr, n_ki, n_ka, broadcast=True) -> KTrajectory:
def example_traj_spiral_2d(n_kr: int, n_ki: int, n_ka: int, broadcast: bool = True) -> KTrajectory:
"""Create 2D spiral trajectory with n_kr points along each spiral arm, n_ki
turns per spiral arm and n_ka spiral arms."""
ang = repeat(torch.linspace(0, 2 * torch.pi * n_ki, n_kr), 'k0 -> other k2 k1 k0', other=1, k2=1, k1=1)
@@ -82,3 +84,15 @@ def test_dcf_rpe_traj_voronoi_cuda(n_kr, n_ka, n_k0):
trajectory = example_traj_rpe(n_kr, n_ka, n_k0)
dcf = DcfData.from_traj_voronoi(trajectory.cuda())
assert dcf.data.is_cuda


def test_dcf_broadcast():
"""Test broadcasting within voronoi dcf calculation."""
rng = RandomGenerator(0)
# kx and ky force voronoi calculation and need to be broadcasted
kx = rng.float32_tensor((1, 1, 4, 4))
ky = rng.float32_tensor((1, 4, 1, 4))
kz = torch.zeros(1, 1, 1, 1)
trajectory = KTrajectory(kz, ky, kx)
dcf = DcfData.from_traj_voronoi(trajectory)
assert dcf.data.shape == trajectory.broadcasted_shape
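The new test asserts that the computed DCF has the trajectory's broadcast shape. A tiny sketch of how that common shape follows from the three component shapes used above (torch.broadcast_shapes is standard PyTorch, not part of this diff):

import torch

kz_shape, ky_shape, kx_shape = (1, 1, 1, 1), (1, 4, 1, 4), (1, 1, 4, 4)
print(torch.broadcast_shapes(kz_shape, ky_shape, kx_shape))  # torch.Size([1, 4, 4, 4])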
23 changes: 22 additions & 1 deletion tests/data/test_movedatamixin.py
@@ -207,7 +207,7 @@ def testchild(attribute, expected_dtype):
assert new.module.module1.weight is new.module.module1.weight, 'shared module parameters should remain shared'


def test_movedatamixin_apply():
def test_movedatamixin_apply_():
"""Tests apply_ method of MoveDataMixin."""
data = B()
# make one of the parameters shared to test memo behavior
@@ -223,3 +223,24 @@ def multiply_by_2(obj):
torch.testing.assert_close(data.floattensor, original.floattensor * 2)
torch.testing.assert_close(data.child.floattensor2, original.child.floattensor2 * 2)
assert data.child.floattensor is data.child.floattensor2, 'shared module parameters should remain shared'


def test_movedatamixin_apply():
"""Tests apply method of MoveDataMixin."""
data = B()
# make one of the parameters shared to test memo behavior
data.child.floattensor2 = data.child.floattensor
original = data.clone()

def multiply_by_2(obj):
if isinstance(obj, torch.Tensor):
return obj * 2
return obj

new = data.apply(multiply_by_2)
torch.testing.assert_close(data.floattensor, original.floattensor)
torch.testing.assert_close(data.child.floattensor2, original.child.floattensor2)
torch.testing.assert_close(new.floattensor, original.floattensor * 2)
torch.testing.assert_close(new.child.floattensor2, original.child.floattensor2 * 2)
assert data.child.floattensor is data.child.floattensor2, 'shared module parameters should remain shared'
assert new is not data, 'new object should be different from the original'
30 changes: 29 additions & 1 deletion tests/data/test_spatial_dimension.py
@@ -93,7 +93,7 @@ def test_spatial_dimension_broadcasting():


def test_spatial_dimension_apply_():
"""Test apply_ (inplace)"""
"""Test apply_ (in place)"""

def conversion(x: torch.Tensor) -> torch.Tensor:
assert isinstance(x, torch.Tensor), 'The argument to the conversion function should be a tensor'
@@ -115,6 +115,34 @@ def conversion(x: torch.Tensor) -> torch.Tensor:
assert torch.equal(spatial_dimension_inplace.z, z)


def test_spatial_dimension_apply():
"""Test apply (out of place)"""

def conversion(x: torch.Tensor) -> torch.Tensor:
assert isinstance(x, torch.Tensor), 'The argument to the conversion function should be a tensor'
return x.swapaxes(0, 1).square()

xyz = RandomGenerator(0).float32_tensor((1, 2, 3))
spatial_dimension = SpatialDimension.from_array_xyz(xyz.numpy())
spatial_dimension_outofplace = spatial_dimension.apply(conversion)

assert spatial_dimension_outofplace is not spatial_dimension

assert isinstance(spatial_dimension_outofplace.x, torch.Tensor)
assert isinstance(spatial_dimension_outofplace.y, torch.Tensor)
assert isinstance(spatial_dimension_outofplace.z, torch.Tensor)

x, y, z = conversion(xyz).unbind(-1)
assert torch.equal(spatial_dimension_outofplace.x, x)
assert torch.equal(spatial_dimension_outofplace.y, y)
assert torch.equal(spatial_dimension_outofplace.z, z)

x, y, z = xyz.unbind(-1) # original should be unmodified
assert torch.equal(spatial_dimension.x, x)
assert torch.equal(spatial_dimension.y, y)
assert torch.equal(spatial_dimension.z, z)


def test_spatial_dimension_zyx():
"""Test the zyx tuple property"""
z, y, x = (2, 3, 4)
6 changes: 6 additions & 0 deletions tests/operators/test_operators.py
@@ -337,6 +337,12 @@ def test_sum_operator_multiple_adjoint():
dotproduct_adjointness_test(linear_op_sum, u, v)


def test_adjoint_of_adjoint():
"""Test that the adjoint of the adjoint is the original operator"""
a = DummyLinearOperator(RandomGenerator(7).complex64_tensor((3, 10)))
assert a.H.H is a


def test_gram_shortcuts():
"""Test that .gram for composition and scalar multiplication results in shortcuts."""

