Skip to content

Commit

Permalink
some more changes
Browse files Browse the repository at this point in the history
  • Loading branch information
fzimmermann89 committed Sep 28, 2024
1 parent 20f5460 commit d476a24
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/mrpro/algorithms/prewhiten_kspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def prewhiten_kspace(kdata: KData, knoise: KNoise, scale_factor: float | torch.T

# solve_triangular is numerically more stable than inverting the matrix
# but requires a single batch dimension
kdata_flat = rearrange(kdata.data, '... coil k2 k1 k0 -> (... k2 k1 k0) coil 1')
kdata_flat = rearrange(kdata.data, '... coils k2 k1 k0 -> (... k2 k1 k0) coils 1')
whitened_flat = torch.linalg.solve_triangular(cholsky, kdata_flat, upper=False)
whitened_flatother = rearrange(
whitened_flat, '(other k2 k1 k0) coil 1-> other coil k2 k1 k0', **parse_shape(kdata.data, '... k2 k1 k0')
Expand Down
6 changes: 3 additions & 3 deletions src/mrpro/data/KTrajectoryRawShape.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def sort_and_reshape(
KTrajectory with kx, ky and kz each in the shape (other k2 k1 k0).
"""
# Resort and reshape
kz = rearrange(self.kz[sort_idx, ...], '(other k2 k1) k0 -> other k2 k1 k0', k1=n_k1, k2=n_k2)
ky = rearrange(self.ky[sort_idx, ...], '(other k2 k1) k0 -> other k2 k1 k0', k1=n_k1, k2=n_k2)
kx = rearrange(self.kx[sort_idx, ...], '(other k2 k1) k0 -> other k2 k1 k0', k1=n_k1, k2=n_k2)
kz = rearrange(self.kz[sort_idx, ...], '... (other k2 k1) k0 -> ... other k2 k1 k0', k1=n_k1, k2=n_k2)
ky = rearrange(self.ky[sort_idx, ...], '... (other k2 k1) k0 -> ... other k2 k1 k0', k1=n_k1, k2=n_k2)
kx = rearrange(self.kx[sort_idx, ...], '... (other k2 k1) k0 -> ... other k2 k1 k0', k1=n_k1, k2=n_k2)

return KTrajectory(kz, ky, kx, repeat_detection_tolerance=self.repeat_detection_tolerance)
6 changes: 3 additions & 3 deletions src/mrpro/data/_kdata/KData.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ class KData(KDataSplitMixin, KDataRearrangeMixin, KDataSelectMixin, KDataRemoveO
"""Header information for k-space data"""

data: torch.Tensor
"""K-space data. Shape (...other coils k2 k1 k0)"""
"""K-space data. Shape (... other coils k2 k1 k0)"""

traj: KTrajectory
"""K-space trajectory along kz, ky and kx. Shape (...other k2 k1 k0)"""
"""K-space trajectory along kz, ky and kx. Shape (... other k2 k1 k0)"""

@classmethod
def from_file(
Expand Down Expand Up @@ -203,7 +203,7 @@ def sort_and_reshape_tensor_fields(input_tensor: torch.Tensor):
return rearrange(input_tensor[sort_idx], '(other k2 k1) ... -> other k2 k1 ...', k1=n_k1, k2=n_k2)

kheader.acq_info = modify_acq_info(sort_and_reshape_tensor_fields, kheader.acq_info)
kdata = rearrange(kdata[sort_idx], '(other k2 k1) coils k0 -> other coils k2 k1 k0', k1=n_k1, k2=n_k2)
kdata = rearrange(kdata[sort_idx], '... (other k2 k1) coils k0 -> ... other coils k2 k1 k0', k1=n_k1, k2=n_k2)

# Calculate trajectory and check if it matches the kdata shape
match ktrajectory:
Expand Down
12 changes: 6 additions & 6 deletions src/mrpro/operators/CartesianSamplingOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,23 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) ->
kx_idx = ktraj_tensor[-1, ...].round().to(dtype=torch.int64) + sorted_grid_shape.x // 2
else:
sorted_grid_shape.x = ktraj_tensor.shape[-1]
kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k1 k2 k0', other=1, k2=1, k1=1)
kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0 -> k1 k2 k0', k2=1, k1=1)

if traj_type_kzyx[-2] == TrajType.ONGRID: # ky
ky_idx = ktraj_tensor[-2, ...].round().to(dtype=torch.int64) + sorted_grid_shape.y // 2
else:
sorted_grid_shape.y = ktraj_tensor.shape[-2]
ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k1 k2 k0', other=1, k2=1, k0=1)
ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1 -> k1 k2 k0', k2=1, k0=1)

if traj_type_kzyx[-3] == TrajType.ONGRID: # kz
kz_idx = ktraj_tensor[-3, ...].round().to(dtype=torch.int64) + sorted_grid_shape.z // 2
else:
sorted_grid_shape.z = ktraj_tensor.shape[-3]
kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k1 k2 k0', other=1, k1=1, k0=1)
kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2 -> k1 k2 k0', k1=1, k0=1)

# 1D indices into a flattened tensor.
kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx
kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)')
kidx = rearrange(kidx, '... k2 k1 k0 -> ... 1 (k2 k1 k0)')
self.register_buffer('_fft_idx', kidx)
# we can skip the indexing if the data is already sorted
self._needs_indexing = not torch.all(torch.diff(kidx) == 1)
Expand All @@ -90,7 +90,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
if not self._needs_indexing:
return (x,)

x_kflat = rearrange(x, '... coils k2_enc k1_enc k0_enc -> ... coils (k2_enc k1_enc k0_enc)')
x_kflat = rearrange(x, '... k2_enc k1_enc k0_enc -> ... (k2_enc k1_enc k0_enc)')
# take_along_dim does broadcast, so no need for extending here
x_indexed = torch.take_along_dim(x_kflat, self._fft_idx, dim=-1)
# reshape to (... other coils, k2, k1, k0)
Expand All @@ -116,7 +116,7 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
if not self._needs_indexing:
return (y,)

y_kflat = rearrange(y, '... coils k2 k1 k0 -> ... coils (k2 k1 k0)')
y_kflat = rearrange(y, '... k2 k1 k0 -> ... (k2 k1 k0)')

# scatter does not broadcast, so we need to manually broadcast the indices
broadcast_shape = torch.broadcast_shapes(self._fft_idx.shape[:-1], y_kflat.shape[:-1])
Expand Down
8 changes: 4 additions & 4 deletions src/mrpro/operators/SensitivityOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def forward(self, img: torch.Tensor) -> tuple[torch.Tensor,]:
Parameters
----------
img
image data tensor with dimensions (other 1 z y x).
image data tensor with dimensions (... other 1 z y x).
Returns
-------
image data tensor with dimensions (other coils z y x).
image data tensor with dimensions (... other coils z y x).
"""
return (self.csm_tensor * img,)

Expand All @@ -43,10 +43,10 @@ def adjoint(self, img: torch.Tensor) -> tuple[torch.Tensor,]:
Parameters
----------
img
image data tensor with dimensions (other coils z y x).
image data tensor with dimensions (... other coils z y x).
Returns
-------
image data tensor with dimensions (other 1 z y x).
image data tensor with dimensions (... other 1 z y x).
"""
return ((self.csm_tensor.conj() * img).sum(-4, keepdim=True),)
33 changes: 15 additions & 18 deletions tests/data/test_traj_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
import pytest
import torch
from einops import repeat
from mrpro.data import KData
from mrpro.data.enums import AcqFlags
from mrpro.data.traj_calculators import (
Expand All @@ -25,7 +24,7 @@ def valid_rad2d_kheader(monkeypatch, random_kheader):
n_k0 = 256
n_k1 = 10
n_k2 = 1
n_other = (2,1)
n_other = (2, 1)

# List of k1 indices in the shape (*other, 1, k1)
idx_k1 = torch.arange(n_k1, dtype=torch.int32).repeat(*n_other, 1, 1)
Expand All @@ -49,7 +48,7 @@ def radial2D_traj_shape(valid_rad2d_kheader):
n_k0 = valid_rad2d_kheader.acq_info.number_of_samples[0, 0, 0, 0]
n_k1 = valid_rad2d_kheader.acq_info.idx.k1.shape[-1]
n_k2 = 1
n_other = (1,1)
n_other = (1, 1)
return (
torch.Size([*n_other, 1, 1, 1]),
torch.Size([*n_other, n_k2, n_k1, n_k0]),
Expand All @@ -74,8 +73,8 @@ def valid_rpe_kheader(monkeypatch, random_kheader):
n_k0 = 200
n_k1 = 20
n_k2 = 10
n_other= (1,1)

n_other = (1, 1)

# List of k1 and k2 indices in the shape (other, k2, k1)
k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32)
Expand All @@ -102,7 +101,7 @@ def rpe_traj_shape(valid_rpe_kheader):
n_k0 = valid_rpe_kheader.acq_info.number_of_samples[0, 0, 0, 0]
n_k1 = valid_rpe_kheader.acq_info.idx.k1.shape[-1]
n_k2 = valid_rpe_kheader.acq_info.idx.k1.shape[-2]
n_other = (1,1)
n_other = (1, 1)
return (
torch.Size([*n_other, n_k2, n_k1, 1]),
torch.Size([*n_other, n_k2, n_k1, 1]),
Expand Down Expand Up @@ -131,11 +130,10 @@ def test_KTrajectoryRpe_uniform(valid_rpe_kheader):
shift_between_rpe_lines=torch.tensor([0]),
)
trajectory2 = trajectory2_calculator(valid_rpe_kheader)


torch.testing.assert_close(trajectory1.kx[:,:, : n_rpe_lines // 2, :, :], trajectory2.kx[:,:, ::2, :, :])
torch.testing.assert_close(trajectory1.ky[:,:, : n_rpe_lines // 2, :, :], trajectory2.ky[:,:, ::2, :, :])
torch.testing.assert_close(trajectory1.kz[:,:, : n_rpe_lines // 2, :, :], trajectory2.kz[:,:, ::2, :, :])

torch.testing.assert_close(trajectory1.kx[:, :, : n_rpe_lines // 2, :, :], trajectory2.kx[:, :, ::2, :, :])
torch.testing.assert_close(trajectory1.ky[:, :, : n_rpe_lines // 2, :, :], trajectory2.ky[:, :, ::2, :, :])
torch.testing.assert_close(trajectory1.kz[:, :, : n_rpe_lines // 2, :, :], trajectory2.kz[:, :, ::2, :, :])


def test_KTrajectoryRpe_shift(valid_rpe_kheader):
Expand Down Expand Up @@ -164,16 +162,15 @@ def valid_cartesian_kheader(monkeypatch, random_kheader):
n_k0 = 200
n_k1 = 20
n_k2 = 10
n_other = (2,2)
n_other = (2, 2)

# List of k1 and k2 indices in the shape (other, k2, k1)
k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32)
k2 = torch.linspace(0, n_k2 - 1, n_k2, dtype=torch.int32)
idx_k1, idx_k2 = torch.meshgrid(k1, k2, indexing='xy')
idx_k1 = idx_k1.reshape(n_k2, n_k1)).repeat(*n_other, 1, 1)
idx_k2 = torch.reshape(n_k2, n_k1)).repeat(*n_other, 1, 1)


idx_k1 = idx_k1.reshape(n_k2, n_k1).repeat(*n_other, 1, 1)
idx_k2 = idx_k2.reshape(n_k2, n_k1).repeat(*n_other, 1, 1)

# Set parameters for trajectory (AcqInfo is of shape (*other k2 k1 dim=1 or 3))
monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0)
monkeypatch.setattr(random_kheader.acq_info, 'center_sample', torch.zeros_like(idx_k1)[..., None] + n_k0 // 2)
Expand All @@ -189,10 +186,10 @@ def valid_cartesian_kheader(monkeypatch, random_kheader):

def cartesian_traj_shape(valid_cartesian_kheader):
"""Expected shape of trajectory based on KHeader."""
n_k0 = valid_cartesian_kheader.acq_info.number_of_samples[0,0,0,0]
n_k0 = valid_cartesian_kheader.acq_info.number_of_samples[0, 0, 0, 0]
n_k1 = valid_cartesian_kheader.acq_info.idx.k1.shape[-1]
n_k2 = valid_cartesian_kheader.acq_info.idx.k1.shape[-2]
n_other = (1,1) # trajectory along other is the same
n_other = (1, 1) # trajectory along other is the same
return (torch.Size([*n_other, n_k2, 1, 1]), torch.Size([*n_other, 1, n_k1, 1]), torch.Size([*n_other, 1, 1, n_k0]))


Expand Down

0 comments on commit d476a24

Please sign in to comment.