Skip to content

Commit

Permalink
Exclude data outside of encoding_matrix (#234)
Browse files Browse the repository at this point in the history
Co-authored-by: Felix Zimmermann <fzimmermann89@gmail.com>
  • Loading branch information
ckolbPTB and fzimmermann89 authored Nov 12, 2024
1 parent 72c1807 commit 202d395
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 16 deletions.
98 changes: 84 additions & 14 deletions src/mrpro/operators/CartesianSamplingOp.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Cartesian Sampling Operator."""

import warnings

import torch
from einops import rearrange, repeat

from mrpro.data.enums import TrajType
from mrpro.data.KTrajectory import KTrajectory
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.operators.LinearOperator import LinearOperator
from mrpro.utils.reshape import unsqueeze_left


class CartesianSamplingOp(LinearOperator):
Expand Down Expand Up @@ -64,10 +67,35 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) ->
# 1D indices into a flattened tensor.
kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx
kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)')

# check that all points are inside the encoding matrix
inside_encoding_matrix = (
((kx_idx >= 0) & (kx_idx < sorted_grid_shape.x))
& ((ky_idx >= 0) & (ky_idx < sorted_grid_shape.y))
& ((kz_idx >= 0) & (kz_idx < sorted_grid_shape.z))
)
if not torch.all(inside_encoding_matrix):
warnings.warn(
'K-space points lie outside of the encoding_matrix and will be ignored.'
' Increase the encoding_matrix to include these points.',
stacklevel=2,
)

inside_encoding_matrix = rearrange(inside_encoding_matrix, '... kz ky kx -> ... 1 (kz ky kx)')
inside_encoding_matrix_idx = inside_encoding_matrix.nonzero(as_tuple=True)[-1]
inside_encoding_matrix_idx = torch.reshape(inside_encoding_matrix_idx, (*kidx.shape[:-1], -1))
self.register_buffer('_inside_encoding_matrix_idx', inside_encoding_matrix_idx)
kidx = torch.take_along_dim(kidx, inside_encoding_matrix_idx, dim=-1)
else:
self._inside_encoding_matrix_idx: torch.Tensor | None = None

self.register_buffer('_fft_idx', kidx)

# we can skip the indexing if the data is already sorted
self._needs_indexing = (
not torch.all(torch.diff(kidx) == 1) or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx
not torch.all(torch.diff(kidx) == 1)
or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx
or self._inside_encoding_matrix_idx is not None
)

self._trajectory_shape = traj.broadcasted_shape
Expand All @@ -93,8 +121,21 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
return (x,)

x_kflat = rearrange(x, '... coil k2_enc k1_enc k0_enc -> ... coil (k2_enc k1_enc k0_enc)')
# take_along_dim does broadcast, so no need for extending here
x_indexed = torch.take_along_dim(x_kflat, self._fft_idx, dim=-1)
# take_along_dim broadcasts, but needs the same number of dimensions
idx = unsqueeze_left(self._fft_idx, x_kflat.ndim - self._fft_idx.ndim)
x_inside_encoding_matrix = torch.take_along_dim(x_kflat, idx, dim=-1)

if self._inside_encoding_matrix_idx is None:
# all trajectory points are inside the encoding matrix
x_indexed = x_inside_encoding_matrix
else:
# we need to add zeros
x_indexed = self._broadcast_and_scatter_along_last_dim(
x_inside_encoding_matrix,
self._trajectory_shape[-1] * self._trajectory_shape[-2] * self._trajectory_shape[-3],
self._inside_encoding_matrix_idx,
)

# reshape to (... other coil, k2, k1, k0)
x_reshaped = x_indexed.reshape(x.shape[:-3] + self._trajectory_shape[-3:])

Expand All @@ -120,18 +161,13 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:

y_kflat = rearrange(y, '... coil k2 k1 k0 -> ... coil (k2 k1 k0)')

# scatter does not broadcast, so we need to manually broadcast the indices
broadcast_shape = torch.broadcast_shapes(self._fft_idx.shape[:-1], y_kflat.shape[:-1])
idx_expanded = torch.broadcast_to(self._fft_idx, (*broadcast_shape, self._fft_idx.shape[-1]))
if self._inside_encoding_matrix_idx is not None:
idx = unsqueeze_left(self._inside_encoding_matrix_idx, y_kflat.ndim - self._inside_encoding_matrix_idx.ndim)
y_kflat = torch.take_along_dim(y_kflat, idx, dim=-1)

# although scatter_ is inplace, this will not cause issues with autograd, as self
# is always constant zero and gradients w.r.t. src work as expected.
y_scattered = torch.zeros(
*broadcast_shape,
self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x,
dtype=y.dtype,
device=y.device,
).scatter_(dim=-1, index=idx_expanded, src=y_kflat)
y_scattered = self._broadcast_and_scatter_along_last_dim(
y_kflat, self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, self._fft_idx
)

# reshape to ..., other, coil, k2_enc, k1_enc, k0_enc
y_reshaped = y_scattered.reshape(
Expand All @@ -142,3 +178,37 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
)

return (y_reshaped,)

@staticmethod
def _broadcast_and_scatter_along_last_dim(
data_to_scatter: torch.Tensor, n_last_dim: int, scatter_index: torch.Tensor
) -> torch.Tensor:
"""Broadcast scatter index and scatter into zero tensor.
Parameters
----------
data_to_scatter
Data to be scattered at indices scatter_index
n_last_dim
Number of data points in last dimension
scatter_index
Indices describing where to scatter data
Returns
-------
Data scattered into tensor along scatter_index
"""
# scatter does not broadcast, so we need to manually broadcast the indices
broadcast_shape = torch.broadcast_shapes(scatter_index.shape[:-1], data_to_scatter.shape[:-1])
idx_expanded = torch.broadcast_to(scatter_index, (*broadcast_shape, scatter_index.shape[-1]))

# although scatter_ is inplace, this will not cause issues with autograd, as self
# is always constant zero and gradients w.r.t. src work as expected.
data_scattered = torch.zeros(
*broadcast_shape,
n_last_dim,
dtype=data_to_scatter.dtype,
device=data_to_scatter.device,
).scatter_(dim=-1, index=idx_expanded, src=data_to_scatter)

return data_scattered
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz):
k_list = []
for spacing, nk in zip([type_kz, type_ky, type_kx], [nkz, nky, nkx], strict=True):
if spacing == 'non-uniform':
k = random_generator.float32_tensor(size=nk)
k = random_generator.float32_tensor(size=nk, low=-1, high=1) * max(nk)
elif spacing == 'uniform':
k = create_uniform_traj(nk, k_shape=k_shape)
elif spacing == 'zero':
Expand Down
25 changes: 25 additions & 0 deletions tests/operators/test_cartesian_sampling_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,28 @@ def test_cart_sampling_op_fwd_adj(sampling):
u = random_generator.complex64_tensor(size=k_shape)
v = random_generator.complex64_tensor(size=k_shape[:2] + trajectory.as_tensor().shape[2:])
dotproduct_adjointness_test(sampling_op, u, v)


@pytest.mark.parametrize(('k2_min', 'k2_max'), [(-1, 21), (-21, 1)])
@pytest.mark.parametrize(('k0_min', 'k0_max'), [(-6, 13), (-13, 6)])
def test_cart_sampling_op_oversampling(k0_min, k0_max, k2_min, k2_max):
"""Test trajectory points outside of encoding_matrix."""
encoding_matrix = SpatialDimension(40, 1, 20)

# Create kx and kz sampling which are asymmetric and larger than the encoding matrix on one side
# The indices are inverted to ensure CartesianSamplingOp acts on them
kx = rearrange(torch.linspace(k0_max, k0_min, 20), 'kx->1 1 1 kx')
ky = torch.ones(1, 1, 1, 1)
kz = rearrange(torch.linspace(k2_max, k2_min, 40), 'kz-> kz 1 1')
kz = torch.stack([kz, -kz], dim=0) # different kz values for two other elements
trajectory = KTrajectory(kz=kz, ky=ky, kx=kx)

with pytest.warns(UserWarning, match='K-space points lie outside of the encoding_matrix'):
sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory)

random_generator = RandomGenerator(seed=0)
u = random_generator.complex64_tensor(size=(3, 2, 5, kz.shape[-3], ky.shape[-2], kx.shape[-1]))
v = random_generator.complex64_tensor(size=(3, 2, 5, *encoding_matrix.zyx))

assert sampling_op.adjoint(u)[0].shape[-3:] == encoding_matrix.zyx
assert sampling_op(v)[0].shape[-3:] == (kz.shape[-3], ky.shape[-2], kx.shape[-1])
6 changes: 5 additions & 1 deletion tests/operators/test_fourier_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ def test_fourier_op_not_supported_traj(im_shape, k_shape, nkx, nky, nkz, type_kx

# create operator
recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1])
encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1])
encoding_matrix = SpatialDimension(
int(trajectory.kz.max() - trajectory.kz.min() + 1),
int(trajectory.ky.max() - trajectory.ky.min() + 1),
int(trajectory.kx.max() - trajectory.kx.min() + 1),
)
with pytest.raises(NotImplementedError, match='Cartesian FFT dims need to be aligned'):
FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory)

Expand Down

3 comments on commit 202d395

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/mrpro/algorithms/csm
   inati.py24196%44
   walsh.py16194%34
src/mrpro/algorithms/dcf
   dcf_voronoi.py53492%15, 48–49, 76
src/mrpro/algorithms/optimizers
   adam.py20195%69
src/mrpro/algorithms/reconstruction
   DirectReconstruction.py281643%51–71, 85
   IterativeSENSEReconstruction.py13192%76
   Reconstruction.py502256%42, 54–56, 80–87, 104–113
   RegularizedIterativeSENSEReconstruction.py411759%96–100, 114–139
src/mrpro/data
   AcqInfo.py128398%26, 169, 207
   CsmData.py29390%15, 82–84
   DcfData.py45882%18, 66, 78–83
   IData.py67987%119, 125, 129, 159–167
   IHeader.py75791%75, 109, 127–131
   KHeader.py1531789%25, 119–123, 150, 199, 210, 217–218, 221, 228, 260–271
   KNoise.py311552%39–52, 56–61
   KTrajectory.py69593%178–182
   MoveDataMixin.py1371887%15, 113, 129, 143–145, 207, 305–307, 320, 399, 419–420, 422, 437–438, 440
   QData.py39782%42, 65–73
   Rotation.py6743595%100, 198, 335, 433, 477, 495, 581, 583, 592, 626, 628, 691, 768, 773, 776, 791, 808, 813, 889, 1077, 1082, 1085, 1109, 1113, 1240, 1242, 1250–1251, 1315, 1397, 1690, 1846, 1881, 1885, 1996
   SpatialDimension.py2302191%33, 103, 128, 135, 141, 261–263, 276–278, 312, 330, 343, 356, 369, 382, 391–392, 407, 416
   acq_filters.py12192%47
src/mrpro/data/_kdata
   KData.py1121884%108–109, 124, 131, 141, 149, 203–204, 242, 247–248, 267–278
   KDataRemoveOsMixin.py29293%44, 46
   KDataSelectMixin.py19289%48, 63
   KDataSplitMixin.py48394%53, 84, 93
src/mrpro/data/traj_calculators
   KTrajectoryCalculator.py25292%23, 45
   KTrajectoryIsmrmrd.py13285%41, 50
   KTrajectoryPulseq.py29197%54
src/mrpro/operators
   CartesianSamplingOp.py72297%118, 157
   ConstraintsOp.py60297%46, 48
   EndomorphOperator.py65297%228, 234
   FiniteDifferenceOp.py27293%40, 105
   Functional.py71593%20–22, 117, 119
   GridSamplingOp.py136993%72–73, 82–83, 90–91, 94, 96, 98
   LinearOperator.py1711293%55, 91, 190, 220, 261, 270, 278, 287, 295, 320, 418, 423
   LinearOperatorMatrix.py1581690%82, 119, 152, 161, 166, 175–178, 191–194, 203, 215, 304, 331, 359
   MultiIdentityOp.py13285%43, 48
   Operator.py78297%25, 74
   ProximableFunctionalSeparableSum.py39392%50, 103, 110
   SliceProjectionOp.py173895%44, 61, 63, 69, 206, 227, 260, 300
   WaveletOp.py120596%152, 170, 205, 210, 233
   ZeroPadOp.py16194%30
src/mrpro/utils
   filters.py62297%44, 49
   slice_profiles.py46687%20, 36, 113–116, 149
   sliding_window.py34197%34
   split_idx.py10280%43, 47
   summarize_tensorvalues.py11918%20–29
   typing.py181139%8–23
   zero_pad_or_crop.py31681%26, 30, 54, 57, 60, 63
TOTAL475435093% 

Tests Skipped Failures Errors Time
1971 0 💤 0 ❌ 0 🔥 1m 28s ⏱️

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/mrpro/algorithms/csm
   inati.py24196%44
   walsh.py16194%34
src/mrpro/algorithms/dcf
   dcf_voronoi.py53492%15, 48–49, 76
src/mrpro/algorithms/optimizers
   adam.py20195%69
src/mrpro/algorithms/reconstruction
   DirectReconstruction.py281643%51–71, 85
   IterativeSENSEReconstruction.py13192%76
   Reconstruction.py502256%42, 54–56, 80–87, 104–113
   RegularizedIterativeSENSEReconstruction.py411759%96–100, 114–139
src/mrpro/data
   AcqInfo.py128398%26, 169, 207
   CsmData.py29390%15, 82–84
   DcfData.py45882%18, 66, 78–83
   IData.py67987%119, 125, 129, 159–167
   IHeader.py75791%75, 109, 127–131
   KHeader.py1531789%25, 119–123, 150, 199, 210, 217–218, 221, 228, 260–271
   KNoise.py311552%39–52, 56–61
   KTrajectory.py69593%178–182
   MoveDataMixin.py1371887%15, 113, 129, 143–145, 207, 305–307, 320, 399, 419–420, 422, 437–438, 440
   QData.py39782%42, 65–73
   Rotation.py6743595%100, 198, 335, 433, 477, 495, 581, 583, 592, 626, 628, 691, 768, 773, 776, 791, 808, 813, 889, 1077, 1082, 1085, 1109, 1113, 1240, 1242, 1250–1251, 1315, 1397, 1690, 1846, 1881, 1885, 1996
   SpatialDimension.py2302191%33, 103, 128, 135, 141, 261–263, 276–278, 312, 330, 343, 356, 369, 382, 391–392, 407, 416
   acq_filters.py12192%47
src/mrpro/data/_kdata
   KData.py1121884%108–109, 124, 131, 141, 149, 203–204, 242, 247–248, 267–278
   KDataRemoveOsMixin.py29293%44, 46
   KDataSelectMixin.py19289%48, 63
   KDataSplitMixin.py48394%53, 84, 93
src/mrpro/data/traj_calculators
   KTrajectoryCalculator.py25292%23, 45
   KTrajectoryIsmrmrd.py13285%41, 50
   KTrajectoryPulseq.py29197%54
src/mrpro/operators
   CartesianSamplingOp.py72297%118, 157
   ConstraintsOp.py60297%46, 48
   EndomorphOperator.py65297%228, 234
   FiniteDifferenceOp.py27293%40, 105
   Functional.py71593%20–22, 117, 119
   GridSamplingOp.py136993%72–73, 82–83, 90–91, 94, 96, 98
   LinearOperator.py1711293%55, 91, 190, 220, 261, 270, 278, 287, 295, 320, 418, 423
   LinearOperatorMatrix.py1581690%82, 119, 152, 161, 166, 175–178, 191–194, 203, 215, 304, 331, 359
   MultiIdentityOp.py13285%43, 48
   Operator.py78297%25, 74
   ProximableFunctionalSeparableSum.py39392%50, 103, 110
   SliceProjectionOp.py173895%44, 61, 63, 69, 206, 227, 260, 300
   WaveletOp.py120596%152, 170, 205, 210, 233
   ZeroPadOp.py16194%30
src/mrpro/utils
   filters.py62297%44, 49
   slice_profiles.py46687%20, 36, 113–116, 149
   sliding_window.py34197%34
   split_idx.py10280%43, 47
   summarize_tensorvalues.py11918%20–29
   typing.py181139%8–23
   zero_pad_or_crop.py31681%26, 30, 54, 57, 60, 63
TOTAL475435093% 

Tests Skipped Failures Errors Time
1971 0 💤 0 ❌ 0 🔥 1m 30s ⏱️

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/mrpro/algorithms/csm
   inati.py24196%44
   walsh.py16194%34
src/mrpro/algorithms/dcf
   dcf_voronoi.py53492%15, 48–49, 76
src/mrpro/algorithms/optimizers
   adam.py20195%69
src/mrpro/algorithms/reconstruction
   DirectReconstruction.py281643%51–71, 85
   IterativeSENSEReconstruction.py13192%76
   Reconstruction.py502256%42, 54–56, 80–87, 104–113
   RegularizedIterativeSENSEReconstruction.py411759%96–100, 114–139
src/mrpro/data
   AcqInfo.py128398%26, 169, 207
   CsmData.py29390%15, 82–84
   DcfData.py45882%18, 66, 78–83
   IData.py67987%119, 125, 129, 159–167
   IHeader.py75791%75, 109, 127–131
   KHeader.py1531789%25, 119–123, 150, 199, 210, 217–218, 221, 228, 260–271
   KNoise.py311552%39–52, 56–61
   KTrajectory.py69593%178–182
   MoveDataMixin.py1371887%15, 113, 129, 143–145, 207, 305–307, 320, 399, 419–420, 422, 437–438, 440
   QData.py39782%42, 65–73
   Rotation.py6743595%100, 198, 335, 433, 477, 495, 581, 583, 592, 626, 628, 691, 768, 773, 776, 791, 808, 813, 889, 1077, 1082, 1085, 1109, 1113, 1240, 1242, 1250–1251, 1315, 1397, 1690, 1846, 1881, 1885, 1996
   SpatialDimension.py2302191%33, 103, 128, 135, 141, 261–263, 276–278, 312, 330, 343, 356, 369, 382, 391–392, 407, 416
   acq_filters.py12192%47
src/mrpro/data/_kdata
   KData.py1121884%108–109, 124, 131, 141, 149, 203–204, 242, 247–248, 267–278
   KDataRemoveOsMixin.py29293%44, 46
   KDataSelectMixin.py19289%48, 63
   KDataSplitMixin.py48394%53, 84, 93
src/mrpro/data/traj_calculators
   KTrajectoryCalculator.py25292%23, 45
   KTrajectoryIsmrmrd.py13285%41, 50
   KTrajectoryPulseq.py29197%54
src/mrpro/operators
   CartesianSamplingOp.py72297%118, 157
   ConstraintsOp.py60297%46, 48
   EndomorphOperator.py65297%228, 234
   FiniteDifferenceOp.py27293%40, 105
   Functional.py71593%20–22, 117, 119
   GridSamplingOp.py136993%72–73, 82–83, 90–91, 94, 96, 98
   LinearOperator.py1711293%55, 91, 190, 220, 261, 270, 278, 287, 295, 320, 418, 423
   LinearOperatorMatrix.py1581690%82, 119, 152, 161, 166, 175–178, 191–194, 203, 215, 304, 331, 359
   MultiIdentityOp.py13285%43, 48
   Operator.py78297%25, 74
   ProximableFunctionalSeparableSum.py39392%50, 103, 110
   SliceProjectionOp.py173895%44, 61, 63, 69, 206, 227, 260, 300
   WaveletOp.py120596%152, 170, 205, 210, 233
   ZeroPadOp.py16194%30
src/mrpro/utils
   filters.py62297%44, 49
   slice_profiles.py46687%20, 36, 113–116, 149
   sliding_window.py34197%34
   split_idx.py10280%43, 47
   summarize_tensorvalues.py11918%20–29
   typing.py181139%8–23
   zero_pad_or_crop.py31681%26, 30, 54, 57, 60, 63
TOTAL475435093% 

Tests Skipped Failures Errors Time
1971 0 💤 0 ❌ 0 🔥 1m 42s ⏱️

Please sign in to comment.