Fix CartesianSamplingOp (#483)
ckolbPTB authored Nov 9, 2024
1 parent c9286b8 commit 6f2378c
Showing 2 changed files with 21 additions and 6 deletions.
10 changes: 6 additions & 4 deletions src/mrpro/operators/CartesianSamplingOp.py
@@ -47,26 +47,28 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) ->
             kx_idx = ktraj_tensor[-1, ...].round().to(dtype=torch.int64) + sorted_grid_shape.x // 2
         else:
             sorted_grid_shape.x = ktraj_tensor.shape[-1]
-            kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k1 k2 k0', other=1, k2=1, k1=1)
+            kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k2 k1 k0', other=1, k2=1, k1=1)
 
         if traj_type_kzyx[-2] == TrajType.ONGRID:  # ky
             ky_idx = ktraj_tensor[-2, ...].round().to(dtype=torch.int64) + sorted_grid_shape.y // 2
         else:
             sorted_grid_shape.y = ktraj_tensor.shape[-2]
-            ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k1 k2 k0', other=1, k2=1, k0=1)
+            ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k2 k1 k0', other=1, k2=1, k0=1)
 
         if traj_type_kzyx[-3] == TrajType.ONGRID:  # kz
             kz_idx = ktraj_tensor[-3, ...].round().to(dtype=torch.int64) + sorted_grid_shape.z // 2
         else:
             sorted_grid_shape.z = ktraj_tensor.shape[-3]
-            kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k1 k2 k0', other=1, k1=1, k0=1)
+            kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k2 k1 k0', other=1, k1=1, k0=1)
 
         # 1D indices into a flattened tensor.
         kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx
         kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)')
         self.register_buffer('_fft_idx', kidx)
         # we can skip the indexing if the data is already sorted
-        self._needs_indexing = not torch.all(torch.diff(kidx) == 1)
+        self._needs_indexing = (
+            not torch.all(torch.diff(kidx) == 1) or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx
+        )
 
         self._trajectory_shape = traj.broadcasted_shape
         self._sorted_grid_shape = sorted_grid_shape
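For context (not part of the commit): the corrected `repeat` patterns keep the internal index layout in (other, k2, k1, k0) order, so an off-grid ky or kz index range lands on its own axis instead of a neighbouring one, and the broadened `_needs_indexing` check additionally triggers whenever the trajectory's (k2, k1, k0) shape differs from the sorted grid shape. A minimal sketch of the axis-order point, using only torch and einops (no mrpro):

import torch
from einops import repeat

# Index range for an off-grid ky direction with 5 samples.
n_k1 = 5
old = repeat(torch.arange(n_k1), 'k1->other k1 k2 k0', other=1, k2=1, k0=1)
new = repeat(torch.arange(n_k1), 'k1->other k2 k1 k0', other=1, k2=1, k0=1)

print(old.shape)  # torch.Size([1, 5, 1, 1]) -- ky indices end up on the k2 (kz) axis
print(new.shape)  # torch.Size([1, 1, 5, 1]) -- ky indices on the k1 axis, matching (other, k2, k1, k0)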
17 changes: 15 additions & 2 deletions tests/operators/test_cartesian_sampling_op.py
@@ -2,6 +2,7 @@

 import pytest
 import torch
+from einops import rearrange
 from mrpro.data import KTrajectory, SpatialDimension
 from mrpro.operators import CartesianSamplingOp

@@ -59,6 +60,9 @@ def test_cart_sampling_op_data_match():
         'regular_undersampling',
         'random_undersampling',
         'different_random_undersampling',
+        'cartesian_and_non_cartesian',
+        'kx_ky_along_k0',
+        'kx_ky_along_k0_undersampling',
     ],
 )
 def test_cart_sampling_op_fwd_adj(sampling):
@@ -70,8 +74,8 @@ def test_cart_sampling_op_fwd_adj(sampling):
     nky = (2, 1, 40, 1)
     nkz = (2, 20, 1, 1)
     sx = 'uf'
-    sy = 'uf'
-    sz = 'uf'
+    sy = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf'
+    sz = 'nuf' if sampling == 'cartesian_and_non_cartesian' else 'uf'
     trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, sx, sy, sz).as_tensor()
 
     # Subsample data and trajectory
@@ -94,6 +98,15 @@ def test_cart_sampling_op_fwd_adj(sampling):
                 for traj_one_other in trajectory_tensor.unbind(1)
             ]
             trajectory = KTrajectory.from_tensor(torch.stack(traj_list, dim=1))
+        case 'cartesian_and_non_cartesian':
+            trajectory = KTrajectory.from_tensor(trajectory_tensor)
+        case 'kx_ky_along_k0':
+            trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)')
+            trajectory = KTrajectory.from_tensor(trajectory_tensor)
+        case 'kx_ky_along_k0_undersampling':
+            trajectory_tensor = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)')
+            random_idx = torch.randperm(trajectory_tensor.shape[-1])
+            trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: trajectory_tensor.shape[-1] // 2]])
         case _:
             raise NotImplementedError(f'Test {sampling} not implemented.')

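The new parametrizations cover a trajectory that mixes Cartesian and non-Cartesian directions ('cartesian_and_non_cartesian') and trajectories whose kx/ky samples are all collected along k0, with and without undersampling. As a rough, self-contained illustration of the reshaping the 'kx_ky_along_k0' cases apply to the trajectory tensor (tensor sizes here are hypothetical; this does not import mrpro):

import torch
from einops import rearrange

# Hypothetical trajectory tensor of shape (dim, other, k2, k1, k0).
trajectory_tensor = torch.randn(3, 2, 1, 40, 64)

# 'kx_ky_along_k0': fold the k1 axis into k0 so every sample lies along a single readout axis.
flat = rearrange(trajectory_tensor, '... k1 k0->... 1 (k1 k0)')
print(flat.shape)  # torch.Size([3, 2, 1, 1, 2560])

# 'kx_ky_along_k0_undersampling': additionally keep a random half of the flattened samples.
random_idx = torch.randperm(flat.shape[-1])
undersampled = flat[..., random_idx[: flat.shape[-1] // 2]]
print(undersampled.shape)  # torch.Size([3, 2, 1, 1, 1280])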
