diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index 64068a5d..e1ef60e1 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -140,3 +140,64 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: ) return (y_reshaped,) + + @property + def gram(self) -> 'CartesianSamplingGramOp': + """Return the Gram operator for this Cartesian Sampling Operator. + + Returns + ------- + Gram operator for this Cartesian Sampling Operator + """ + return CartesianSamplingGramOp(self) + + +class CartesianSamplingGramOp(LinearOperator): + """Gram operator for Cartesian Sampling Operator. + + The Gram operator is the composition CartesianSamplingOp.H @ CartesianSamplingOp. + """ + + def __init__(self, sampling_op: CartesianSamplingOp): + """Initialize Cartesian Sampling Gram Operator class. + + This should not be used directly, but rather through the `gram` method of a + :class:`mrpro.operator.CartesianSamplingOp` object. + + Parameters + ---------- + sampling_op + The Cartesian Sampling Operator for which to create the Gram operator. + """ + super().__init__() + ones = torch.ones(*sampling_op._trajectory_shape[:-3], 1, *sampling_op._sorted_grid_shape.zyx) + (mask,) = sampling_op.adjoint(*sampling_op.forward(ones)) + self.register_buffer('_mask', mask) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the Gram operator. + + Parameters + ---------- + x + Input data + + Returns + ------- + Output data + """ + return (x * self._mask,) + + def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint of the Gram operator. + + Parameters + ---------- + y + Input data + + Returns + ------- + Output data + """ + return self.forward(y) diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py index 6a1120e7..3de875ff 100644 --- a/tests/operators/test_cartesian_sampling_op.py +++ b/tests/operators/test_cartesian_sampling_op.py @@ -105,3 +105,58 @@ def test_cart_sampling_op_fwd_adj(sampling): u = random_generator.complex64_tensor(size=k_shape) v = random_generator.complex64_tensor(size=k_shape[:2] + trajectory.as_tensor().shape[2:]) dotproduct_adjointness_test(sampling_op, u, v) + + +@pytest.mark.parametrize( + 'sampling', + [ + 'random', + 'partial_echo', + 'partial_fourier', + 'regular_undersampling', + 'random_undersampling', + 'different_random_undersampling', + ], +) +def test_cart_sampling_op_gram(sampling): + """Test adjoint gram of Cartesian sampling operator.""" + + # Create 3D uniform trajectory + k_shape = (2, 5, 20, 40, 60) + nkx = (2, 1, 1, 60) + nky = (2, 1, 40, 1) + nkz = (2, 20, 1, 1) + sx = 'uf' + sy = 'uf' + sz = 'uf' + trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz).as_tensor() + + # Subsample data and trajectory + match sampling: + case 'random': + random_idx = torch.randperm(k_shape[-2]) + trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx, :]) + case 'partial_echo': + trajectory = KTrajectory.from_tensor(trajectory_tensor[..., : k_shape[-1] // 2]) + case 'partial_fourier': + trajectory = KTrajectory.from_tensor(trajectory_tensor[..., : k_shape[-3] // 2, : k_shape[-2] // 2, :]) + case 'regular_undersampling': + trajectory = KTrajectory.from_tensor(trajectory_tensor[..., ::3, ::5, :]) + case 'random_undersampling': + random_idx = torch.randperm(k_shape[-2]) + trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: k_shape[-2] // 2], :]) + case 'different_random_undersampling': + traj_list = [ + traj_one_other[..., torch.randperm(k_shape[-2])[: k_shape[-2] // 2], :] + for traj_one_other in trajectory_tensor.unbind(1) + ] + trajectory = KTrajectory.from_tensor(torch.stack(traj_list, dim=1)) + case _: + raise NotImplementedError(f'Test {sampling} not implemented.') + + encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1]) + sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory) + u = RandomGenerator(seed=0).complex64_tensor(size=k_shape) + (expected,) = (sampling_op.H @ sampling_op)(u) + (actual,) = sampling_op.gram(u) + torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)