Commit

gram cartesian sampling op
fzimmermann89 committed Nov 9, 2024
1 parent 848a023 commit 3e763e2
Showing 2 changed files with 116 additions and 0 deletions.
61 changes: 61 additions & 0 deletions src/mrpro/operators/CartesianSamplingOp.py
@@ -140,3 +140,64 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
        )

        return (y_reshaped,)

    @property
    def gram(self) -> 'CartesianSamplingGramOp':
        """Return the Gram operator for this Cartesian Sampling Operator.

        Returns
        -------
            Gram operator for this Cartesian Sampling Operator
        """
        return CartesianSamplingGramOp(self)


class CartesianSamplingGramOp(LinearOperator):
    """Gram operator for Cartesian Sampling Operator.

    The Gram operator is the composition CartesianSamplingOp.H @ CartesianSamplingOp.
    """

    def __init__(self, sampling_op: CartesianSamplingOp):
        """Initialize Cartesian Sampling Gram Operator class.

        This should not be used directly, but rather through the `gram` property of a
        :class:`mrpro.operators.CartesianSamplingOp` object.

        Parameters
        ----------
        sampling_op
            The Cartesian Sampling Operator for which to create the Gram operator.
        """
        super().__init__()
        # Pushing an all-ones tensor through forward and adjoint recovers the sampling
        # mask: grid entries covered by the trajectory stay 1, unsampled entries become 0.
        ones = torch.ones(*sampling_op._trajectory_shape[:-3], 1, *sampling_op._sorted_grid_shape.zyx)
        (mask,) = sampling_op.adjoint(*sampling_op.forward(ones))
        self.register_buffer('_mask', mask)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
        """Apply the Gram operator.

        Parameters
        ----------
        x
            Input data

        Returns
        -------
            Output data
        """
        return (x * self._mask,)

    def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
        """Apply the adjoint of the Gram operator.

        Parameters
        ----------
        y
            Input data

        Returns
        -------
            Output data
        """
        # The Gram operator is self-adjoint: (S.H @ S).H = S.H @ S.
        return self.forward(y)
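
Note on the construction above: for Cartesian sampling, the Gram S.H @ S reduces to pointwise multiplication with a mask, which is exactly what pushing an all-ones tensor through forward and adjoint recovers. A minimal self-contained sketch of this identity in plain PyTorch (toy 1D grid, illustrative index values, no mrpro dependencies):

import torch

# Toy 1D 'Cartesian sampling': gather a subset of grid entries (indices are illustrative).
grid_size = 8
idx = torch.tensor([0, 2, 3, 7])  # sampled k-space positions

def forward(x: torch.Tensor) -> torch.Tensor:
    """Gather the sampled entries from the full grid."""
    return x[idx]

def adjoint(y: torch.Tensor) -> torch.Tensor:
    """Scatter the samples back onto a zero-filled grid."""
    out = torch.zeros(grid_size, dtype=y.dtype)
    return out.index_put_((idx,), y, accumulate=True)

# adjoint(forward(ones)) recovers the sampling mask, and the Gram
# adjoint(forward(x)) equals pointwise multiplication with that mask.
mask = adjoint(forward(torch.ones(grid_size)))
x = torch.randn(grid_size)
assert torch.allclose(adjoint(forward(x)), mask * x)

If a grid location were sampled more than once, the accumulated mask would count multiplicities rather than stay binary, but the identity mask * x == adjoint(forward(x)) still holds.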
55 changes: 55 additions & 0 deletions tests/operators/test_cartesian_sampling_op.py
@@ -105,3 +105,58 @@ def test_cart_sampling_op_fwd_adj(sampling):
    u = random_generator.complex64_tensor(size=k_shape)
    v = random_generator.complex64_tensor(size=k_shape[:2] + trajectory.as_tensor().shape[2:])
    dotproduct_adjointness_test(sampling_op, u, v)


@pytest.mark.parametrize(
    'sampling',
    [
        'random',
        'partial_echo',
        'partial_fourier',
        'regular_undersampling',
        'random_undersampling',
        'different_random_undersampling',
    ],
)
def test_cart_sampling_op_gram(sampling):
    """Test the Gram operator of the Cartesian sampling operator."""

    # Create 3D uniform trajectory
    k_shape = (2, 5, 20, 40, 60)
    nkx = (2, 1, 1, 60)
    nky = (2, 1, 40, 1)
    nkz = (2, 20, 1, 1)
    sx = 'uf'
    sy = 'uf'
    sz = 'uf'
    trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz).as_tensor()

    # Subsample data and trajectory
    match sampling:
        case 'random':
            random_idx = torch.randperm(k_shape[-2])
            trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx, :])
        case 'partial_echo':
            trajectory = KTrajectory.from_tensor(trajectory_tensor[..., : k_shape[-1] // 2])
        case 'partial_fourier':
            trajectory = KTrajectory.from_tensor(trajectory_tensor[..., : k_shape[-3] // 2, : k_shape[-2] // 2, :])
        case 'regular_undersampling':
            trajectory = KTrajectory.from_tensor(trajectory_tensor[..., ::3, ::5, :])
        case 'random_undersampling':
            random_idx = torch.randperm(k_shape[-2])
            trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: k_shape[-2] // 2], :])
        case 'different_random_undersampling':
            traj_list = [
                traj_one_other[..., torch.randperm(k_shape[-2])[: k_shape[-2] // 2], :]
                for traj_one_other in trajectory_tensor.unbind(1)
            ]
            trajectory = KTrajectory.from_tensor(torch.stack(traj_list, dim=1))
        case _:
            raise NotImplementedError(f'Test {sampling} not implemented.')

    encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1])
    sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory)
    u = RandomGenerator(seed=0).complex64_tensor(size=k_shape)

    # The Gram operator must match the explicit composition S.H @ S.
    (expected,) = (sampling_op.H @ sampling_op)(u)
    (actual,) = sampling_op.gram(u)
    torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)
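
Because the mask is real-valued, the Gram operator is self-adjoint, which is why adjoint simply delegates to forward in the operator above. For reference, a standalone dot-product adjointness check in the spirit of the dotproduct_adjointness_test used in the earlier test (plain PyTorch; the mask shape and gram helper below are illustrative, not mrpro API):

import torch

# For a self-adjoint operator G, <G u, v> must equal <u, G v> for arbitrary u, v.
torch.manual_seed(0)
mask = (torch.rand(20, 40, 60) > 0.5).to(torch.complex64)  # illustrative 0/1 mask

def gram(x: torch.Tensor) -> torch.Tensor:
    """Mask-based Gram operator, as in CartesianSamplingGramOp.forward."""
    return mask * x

u = torch.randn(20, 40, 60, dtype=torch.complex64)
v = torch.randn(20, 40, 60, dtype=torch.complex64)

# torch.vdot conjugates its first argument, matching the complex inner product.
lhs = torch.vdot(gram(u).flatten(), v.flatten())
rhs = torch.vdot(u.flatten(), gram(v).flatten())
torch.testing.assert_close(lhs, rhs)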
