
Exclude data outside of encoding_matrix #234

Merged
merged 10 commits into from
Nov 12, 2024
98 changes: 84 additions & 14 deletions src/mrpro/operators/CartesianSamplingOp.py
@@ -1,12 +1,15 @@
"""Cartesian Sampling Operator."""

import warnings

import torch
from einops import rearrange, repeat

from mrpro.data.enums import TrajType
from mrpro.data.KTrajectory import KTrajectory
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.operators.LinearOperator import LinearOperator
from mrpro.utils.reshape import unsqueeze_left


class CartesianSamplingOp(LinearOperator):
@@ -64,10 +67,35 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) ->
# 1D indices into a flattened tensor.
kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx
kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)')

# check that all points are inside the encoding matrix
inside_encoding_matrix = (
((kx_idx >= 0) & (kx_idx < sorted_grid_shape.x))
& ((ky_idx >= 0) & (ky_idx < sorted_grid_shape.y))
& ((kz_idx >= 0) & (kz_idx < sorted_grid_shape.z))
)
if not torch.all(inside_encoding_matrix):
warnings.warn(
'K-space points lie outside of the encoding_matrix and will be ignored.'
' Increase the encoding_matrix to include these points.',
stacklevel=2,
)

inside_encoding_matrix = rearrange(inside_encoding_matrix, '... kz ky kx -> ... 1 (kz ky kx)')
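# flat positions (within the flattened (kz ky kx) dimension) of the trajectory points that lie inside the encoding matrix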
inside_encoding_matrix_idx = inside_encoding_matrix.nonzero(as_tuple=True)[-1]
inside_encoding_matrix_idx = torch.reshape(inside_encoding_matrix_idx, (*kidx.shape[:-1], -1))
self.register_buffer('_inside_encoding_matrix_idx', inside_encoding_matrix_idx)
kidx = torch.take_along_dim(kidx, inside_encoding_matrix_idx, dim=-1)
else:
self._inside_encoding_matrix_idx: torch.Tensor | None = None

self.register_buffer('_fft_idx', kidx)

# we can skip the indexing if the data is already sorted
self._needs_indexing = (
not torch.all(torch.diff(kidx) == 1)
or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx
or self._inside_encoding_matrix_idx is not None
)

self._trajectory_shape = traj.broadcasted_shape
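
A minimal sketch (not part of the PR) of the index bookkeeping above, using made-up grid sizes and plain tensors instead of a KTrajectory:

import torch

# hypothetical grid shape (z, y, x) and integer grid indices for three k-space points
nz, ny, nx = 4, 6, 8
kz_idx = torch.tensor([0, 1, 2])
ky_idx = torch.tensor([5, 2, 7])  # the last point has ky_idx >= ny, i.e. it lies outside the matrix
kx_idx = torch.tensor([3, 0, 1])

# same flattening rule as in __init__: 1D index into a tensor with nz*ny*nx entries
kidx = kz_idx * ny * nx + ky_idx * nx + kx_idx
print(kidx)  # -> 43, 64, 153

# same bounds check as in __init__
inside = (
    ((kx_idx >= 0) & (kx_idx < nx))
    & ((ky_idx >= 0) & (ky_idx < ny))
    & ((kz_idx >= 0) & (kz_idx < nz))
)
print(inside)  # -> True, True, False

# only the in-bounds positions are kept for indexing; their locations are remembered
# so that forward() can zero-fill the dropped samples later
print(kidx[inside])  # -> 43, 64
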
@@ -93,8 +121,21 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
return (x,)

x_kflat = rearrange(x, '... coil k2_enc k1_enc k0_enc -> ... coil (k2_enc k1_enc k0_enc)')
# take_along_dim broadcasts, but needs the same number of dimensions
idx = unsqueeze_left(self._fft_idx, x_kflat.ndim - self._fft_idx.ndim)
x_inside_encoding_matrix = torch.take_along_dim(x_kflat, idx, dim=-1)

if self._inside_encoding_matrix_idx is None:
# all trajectory points are inside the encoding matrix
x_indexed = x_inside_encoding_matrix
else:
# some trajectory points were excluded, so scatter the gathered values back to the full trajectory size and zero-fill the excluded positions
x_indexed = self._broadcast_and_scatter_along_last_dim(
x_inside_encoding_matrix,
self._trajectory_shape[-1] * self._trajectory_shape[-2] * self._trajectory_shape[-3],
self._inside_encoding_matrix_idx,
)

# reshape to (..., other, coil, k2, k1, k0)
x_reshaped = x_indexed.reshape(x.shape[:-3] + self._trajectory_shape[-3:])
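
A small illustration (not part of the PR) of the comment in the new forward() code: torch.take_along_dim broadcasts the leading dimensions, but only once both tensors have the same number of dimensions, which is what unsqueeze_left arranges. A plain-torch equivalent with made-up shapes:

import torch

x_kflat = torch.arange(2 * 3 * 10).reshape(2, 3, 10)  # (other, coil, flattened k)
fft_idx = torch.tensor([[4, 2, 0]])                   # (1, n_k), fewer dimensions than x_kflat

# equivalent of unsqueeze_left(fft_idx, x_kflat.ndim - fft_idx.ndim)
idx = fft_idx.reshape((1,) * (x_kflat.ndim - fft_idx.ndim) + fft_idx.shape)

# now the leading singleton dimensions broadcast against (other, coil)
picked = torch.take_along_dim(x_kflat, idx, dim=-1)
print(picked.shape)  # torch.Size([2, 3, 3])
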

@@ -120,18 +161,13 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:

y_kflat = rearrange(y, '... coil k2 k1 k0 -> ... coil (k2 k1 k0)')

if self._inside_encoding_matrix_idx is not None:
idx = unsqueeze_left(self._inside_encoding_matrix_idx, y_kflat.ndim - self._inside_encoding_matrix_idx.ndim)
y_kflat = torch.take_along_dim(y_kflat, idx, dim=-1)

y_scattered = self._broadcast_and_scatter_along_last_dim(
y_kflat, self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, self._fft_idx
)

# reshape to ..., other, coil, k2_enc, k1_enc, k0_enc
y_reshaped = y_scattered.reshape(
@@ -142,3 +178,37 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
)

return (y_reshaped,)

@staticmethod
def _broadcast_and_scatter_along_last_dim(
data_to_scatter: torch.Tensor, n_last_dim: int, scatter_index: torch.Tensor
) -> torch.Tensor:
"""Broadcast scatter index and scatter into zero tensor.

Parameters
----------
data_to_scatter
Data to be scattered at indices scatter_index
n_last_dim
Size of the last dimension of the returned tensor (the scatter target)
scatter_index
Indices describing where to scatter data

Returns
-------
Zero-initialized tensor with data_to_scatter written at scatter_index along the last dimension
"""
# scatter does not broadcast, so we need to manually broadcast the indices
broadcast_shape = torch.broadcast_shapes(scatter_index.shape[:-1], data_to_scatter.shape[:-1])
idx_expanded = torch.broadcast_to(scatter_index, (*broadcast_shape, scatter_index.shape[-1]))

# although scatter_ is inplace, this will not cause issues with autograd, as self
# is always constant zero and gradients w.r.t. src work as expected.
data_scattered = torch.zeros(
*broadcast_shape,
n_last_dim,
dtype=data_to_scatter.dtype,
device=data_to_scatter.device,
).scatter_(dim=-1, index=idx_expanded, src=data_to_scatter)

return data_scattered
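
A short usage sketch (not part of the PR) of the new static helper; the shapes are made up, with the index carrying one fewer leading dimension than the data to show the broadcast:

import torch

from mrpro.operators.CartesianSamplingOp import CartesianSamplingOp

data = torch.arange(1, 7, dtype=torch.float32).reshape(2, 3)  # (batch, n_points)
index = torch.tensor([[5, 0, 2]])                             # (1, n_points), broadcast over batch

out = CartesianSamplingOp._broadcast_and_scatter_along_last_dim(data, 8, index)
print(out.shape)  # torch.Size([2, 8])
print(out[0])     # tensor([2., 0., 3., 0., 0., 1., 0., 0.]): row [1., 2., 3.] scattered to positions 5, 0, 2
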
25 changes: 25 additions & 0 deletions tests/operators/test_cartesian_sampling_op.py
@@ -118,3 +118,28 @@ def test_cart_sampling_op_fwd_adj(sampling):
u = random_generator.complex64_tensor(size=k_shape)
v = random_generator.complex64_tensor(size=k_shape[:2] + trajectory.as_tensor().shape[2:])
dotproduct_adjointness_test(sampling_op, u, v)


@pytest.mark.parametrize(('k2_min', 'k2_max'), [(-1, 21), (-21, 1)])
@pytest.mark.parametrize(('k0_min', 'k0_max'), [(-6, 13), (-13, 6)])
def test_cart_sampling_op_oversampling(k0_min, k0_max, k2_min, k2_max):
"""Test trajectory points outside of encoding_matrix."""
encoding_matrix = SpatialDimension(40, 1, 20)

# Create kx and kz samplings which are asymmetric and extend beyond the encoding matrix on one side
# The values run from max to min so the points are not already sorted and CartesianSamplingOp has to reorder them
kx = rearrange(torch.linspace(k0_max, k0_min, 20), 'kx->1 1 1 kx')
ky = torch.ones(1, 1, 1, 1)
kz = rearrange(torch.linspace(k2_max, k2_min, 40), 'kz-> kz 1 1')
kz = torch.stack([kz, -kz], dim=0) # different kz values for the two entries along the 'other' dimension
trajectory = KTrajectory(kz=kz, ky=ky, kx=kx)

with pytest.warns(UserWarning, match='K-space points lie outside of the encoding_matrix'):
sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory)

random_generator = RandomGenerator(seed=0)
u = random_generator.complex64_tensor(size=(3, 2, 5, kz.shape[-3], ky.shape[-2], kx.shape[-1]))
v = random_generator.complex64_tensor(size=(3, 2, 5, *encoding_matrix.zyx))

assert sampling_op.adjoint(u)[0].shape[-3:] == encoding_matrix.zyx
assert sampling_op(v)[0].shape[-3:] == (kz.shape[-3], ky.shape[-2], kx.shape[-1])
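
For intuition (not part of the PR): forward() gathers samples at the stored flat indices and adjoint() scatters them back onto a zero-filled grid. A plain-torch sketch of why such a gather/scatter pair is adjoint, which is the property dotproduct_adjointness_test presumably verifies for the full operator:

import torch

x = torch.randn(10, dtype=torch.complex64)  # "gridded" k-space data
y = torch.randn(3, dtype=torch.complex64)   # "sampled" k-space data
idx = torch.tensor([7, 2, 5])               # unique sampling positions

forward_x = x[idx]                                                       # gather, as in forward()
adjoint_y = torch.zeros(10, dtype=torch.complex64).scatter_(0, idx, y)  # scatter into zeros, as in adjoint()

# adjointness: <A x, y> == <x, A^H y>, up to floating-point rounding
lhs = torch.vdot(forward_x, y)
rhs = torch.vdot(x, adjoint_y)
print(torch.allclose(lhs, rhs))  # True
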