From d476a2467cdb78c81165a5fb7e313144ca4dc41e Mon Sep 17 00:00:00 2001
From: Felix F Zimmermann
Date: Sat, 28 Sep 2024 16:20:10 +0200
Subject: [PATCH] some more changes

---
 src/mrpro/algorithms/prewhiten_kspace.py   |  2 +-
 src/mrpro/data/KTrajectoryRawShape.py      |  6 ++--
 src/mrpro/data/_kdata/KData.py             |  6 ++--
 src/mrpro/operators/CartesianSamplingOp.py | 12 ++++----
 src/mrpro/operators/SensitivityOp.py       |  8 +++---
 tests/data/test_traj_calculators.py        | 33 ++++++++++------------
 6 files changed, 32 insertions(+), 35 deletions(-)

diff --git a/src/mrpro/algorithms/prewhiten_kspace.py b/src/mrpro/algorithms/prewhiten_kspace.py
index 71f1ba56..231cac02 100644
--- a/src/mrpro/algorithms/prewhiten_kspace.py
+++ b/src/mrpro/algorithms/prewhiten_kspace.py
@@ -61,7 +61,7 @@ def prewhiten_kspace(kdata: KData, knoise: KNoise, scale_factor: float | torch.T
 
     # solve_triangular is numerically more stable than inverting the matrix
     # but requires a single batch dimension
-    kdata_flat = rearrange(kdata.data, '... coil k2 k1 k0 -> (... k2 k1 k0) coil 1')
+    kdata_flat = rearrange(kdata.data, '... coils k2 k1 k0 -> (... k2 k1 k0) coils 1')
     whitened_flat = torch.linalg.solve_triangular(cholsky, kdata_flat, upper=False)
     whitened_flatother = rearrange(
         whitened_flat, '(other k2 k1 k0) coil 1-> other coil k2 k1 k0', **parse_shape(kdata.data, '... k2 k1 k0')
diff --git a/src/mrpro/data/KTrajectoryRawShape.py b/src/mrpro/data/KTrajectoryRawShape.py
index 3730e666..98fbc217 100644
--- a/src/mrpro/data/KTrajectoryRawShape.py
+++ b/src/mrpro/data/KTrajectoryRawShape.py
@@ -56,8 +56,8 @@ def sort_and_reshape(
             KTrajectory with kx, ky and kz each in the shape (other k2 k1 k0).
         """
         # Resort and reshape
-        kz = rearrange(self.kz[sort_idx, ...], '(other k2 k1) k0 -> other k2 k1 k0', k1=n_k1, k2=n_k2)
-        ky = rearrange(self.ky[sort_idx, ...], '(other k2 k1) k0 -> other k2 k1 k0', k1=n_k1, k2=n_k2)
-        kx = rearrange(self.kx[sort_idx, ...], '(other k2 k1) k0 -> other k2 k1 k0', k1=n_k1, k2=n_k2)
+        kz = rearrange(self.kz[sort_idx, ...], '... (other k2 k1) k0 -> ... other k2 k1 k0', k1=n_k1, k2=n_k2)
+        ky = rearrange(self.ky[sort_idx, ...], '... (other k2 k1) k0 -> ... other k2 k1 k0', k1=n_k1, k2=n_k2)
+        kx = rearrange(self.kx[sort_idx, ...], '... (other k2 k1) k0 -> ... other k2 k1 k0', k1=n_k1, k2=n_k2)
 
         return KTrajectory(kz, ky, kx, repeat_detection_tolerance=self.repeat_detection_tolerance)
diff --git a/src/mrpro/data/_kdata/KData.py b/src/mrpro/data/_kdata/KData.py
index 81b4f900..1dde0a17 100644
--- a/src/mrpro/data/_kdata/KData.py
+++ b/src/mrpro/data/_kdata/KData.py
@@ -68,10 +68,10 @@ class KData(KDataSplitMixin, KDataRearrangeMixin, KDataSelectMixin, KDataRemoveO
     """Header information for k-space data"""
 
     data: torch.Tensor
-    """K-space data. Shape (...other coils k2 k1 k0)"""
+    """K-space data. Shape (... other coils k2 k1 k0)"""
 
     traj: KTrajectory
-    """K-space trajectory along kz, ky and kx. Shape (...other k2 k1 k0)"""
+    """K-space trajectory along kz, ky and kx. Shape (... other k2 k1 k0)"""
 
     @classmethod
     def from_file(
@@ -203,7 +203,7 @@ def sort_and_reshape_tensor_fields(input_tensor: torch.Tensor):
            return rearrange(input_tensor[sort_idx], '(other k2 k1) ... -> other k2 k1 ...', k1=n_k1, k2=n_k2)
 
        kheader.acq_info = modify_acq_info(sort_and_reshape_tensor_fields, kheader.acq_info)
-        kdata = rearrange(kdata[sort_idx], '(other k2 k1) coils k0 -> other coils k2 k1 k0', k1=n_k1, k2=n_k2)
+        kdata = rearrange(kdata[sort_idx], '... (other k2 k1) coils k0 -> ... other coils k2 k1 k0', k1=n_k1, k2=n_k2)
 
        # Calculate trajectory and check if it matches the kdata shape
        match ktrajectory:
diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py
index d1205726..5b6e8b76 100644
--- a/src/mrpro/operators/CartesianSamplingOp.py
+++ b/src/mrpro/operators/CartesianSamplingOp.py
@@ -47,23 +47,23 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) ->
             kx_idx = ktraj_tensor[-1, ...].round().to(dtype=torch.int64) + sorted_grid_shape.x // 2
         else:
             sorted_grid_shape.x = ktraj_tensor.shape[-1]
-            kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0->other k1 k2 k0', other=1, k2=1, k1=1)
+            kx_idx = repeat(torch.arange(ktraj_tensor.shape[-1]), 'k0 -> k1 k2 k0', k2=1, k1=1)
 
         if traj_type_kzyx[-2] == TrajType.ONGRID:  # ky
             ky_idx = ktraj_tensor[-2, ...].round().to(dtype=torch.int64) + sorted_grid_shape.y // 2
         else:
             sorted_grid_shape.y = ktraj_tensor.shape[-2]
-            ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1->other k1 k2 k0', other=1, k2=1, k0=1)
+            ky_idx = repeat(torch.arange(ktraj_tensor.shape[-2]), 'k1 -> k1 k2 k0', k2=1, k0=1)
 
         if traj_type_kzyx[-3] == TrajType.ONGRID:  # kz
             kz_idx = ktraj_tensor[-3, ...].round().to(dtype=torch.int64) + sorted_grid_shape.z // 2
         else:
             sorted_grid_shape.z = ktraj_tensor.shape[-3]
-            kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2->other k1 k2 k0', other=1, k1=1, k0=1)
+            kz_idx = repeat(torch.arange(ktraj_tensor.shape[-3]), 'k2 -> k1 k2 k0', k1=1, k0=1)
 
         # 1D indices into a flattened tensor.
         kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx
-        kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)')
+        kidx = rearrange(kidx, '... k2 k1 k0 -> ... 1 (k2 k1 k0)')
         self.register_buffer('_fft_idx', kidx)
         # we can skip the indexing if the data is already sorted
         self._needs_indexing = not torch.all(torch.diff(kidx) == 1)
@@ -90,7 +90,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
         if not self._needs_indexing:
             return (x,)
 
-        x_kflat = rearrange(x, '... coils k2_enc k1_enc k0_enc -> ... coils (k2_enc k1_enc k0_enc)')
+        x_kflat = rearrange(x, '... k2_enc k1_enc k0_enc -> ... (k2_enc k1_enc k0_enc)')
         # take_along_dim does broadcast, so no need for extending here
         x_indexed = torch.take_along_dim(x_kflat, self._fft_idx, dim=-1)
         # reshape to (... other coils, k2, k1, k0)
@@ -116,7 +116,7 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
         if not self._needs_indexing:
             return (y,)
 
-        y_kflat = rearrange(y, '... coils k2 k1 k0 -> ... coils (k2 k1 k0)')
+        y_kflat = rearrange(y, '... k2 k1 k0 -> ... (k2 k1 k0)')
 
         # scatter does not broadcast, so we need to manually broadcast the indices
         broadcast_shape = torch.broadcast_shapes(self._fft_idx.shape[:-1], y_kflat.shape[:-1])
diff --git a/src/mrpro/operators/SensitivityOp.py b/src/mrpro/operators/SensitivityOp.py
index 56294881..365fcf6e 100644
--- a/src/mrpro/operators/SensitivityOp.py
+++ b/src/mrpro/operators/SensitivityOp.py
@@ -29,11 +29,11 @@ def forward(self, img: torch.Tensor) -> tuple[torch.Tensor,]:
         Parameters
         ----------
         img
-            image data tensor with dimensions (other 1 z y x).
+            image data tensor with dimensions (... other 1 z y x).
 
         Returns
         -------
-            image data tensor with dimensions (other coils z y x).
+            image data tensor with dimensions (... other coils z y x).
         """
         return (self.csm_tensor * img,)
 
@@ -43,10 +43,10 @@ def adjoint(self, img: torch.Tensor) -> tuple[torch.Tensor,]:
         Parameters
         ----------
         img
-            image data tensor with dimensions (other coils z y x).
+            image data tensor with dimensions (... other coils z y x).
 
         Returns
         -------
-            image data tensor with dimensions (other 1 z y x).
+            image data tensor with dimensions (... other 1 z y x).
         """
         return ((self.csm_tensor.conj() * img).sum(-4, keepdim=True),)
diff --git a/tests/data/test_traj_calculators.py b/tests/data/test_traj_calculators.py
index 64db2087..755731be 100644
--- a/tests/data/test_traj_calculators.py
+++ b/tests/data/test_traj_calculators.py
@@ -3,7 +3,6 @@
 import numpy as np
 import pytest
 import torch
-from einops import repeat
 from mrpro.data import KData
 from mrpro.data.enums import AcqFlags
 from mrpro.data.traj_calculators import (
@@ -25,7 +24,7 @@ def valid_rad2d_kheader(monkeypatch, random_kheader):
     n_k0 = 256
     n_k1 = 10
     n_k2 = 1
-    n_other = (2,1)
+    n_other = (2, 1)
 
     # List of k1 indices in the shape (*other, 1, k1)
     idx_k1 = torch.arange(n_k1, dtype=torch.int32).repeat(*n_other, 1, 1)
@@ -49,7 +48,7 @@ def radial2D_traj_shape(valid_rad2d_kheader):
     n_k0 = valid_rad2d_kheader.acq_info.number_of_samples[0, 0, 0, 0]
     n_k1 = valid_rad2d_kheader.acq_info.idx.k1.shape[-1]
     n_k2 = 1
-    n_other = (1,1)
+    n_other = (1, 1)
     return (
         torch.Size([*n_other, 1, 1, 1]),
         torch.Size([*n_other, n_k2, n_k1, n_k0]),
@@ -74,8 +73,8 @@ def valid_rpe_kheader(monkeypatch, random_kheader):
     n_k0 = 200
     n_k1 = 20
     n_k2 = 10
-
-    n_other= (1,1)
+
+    n_other = (1, 1)
 
     # List of k1 and k2 indices in the shape (other, k2, k1)
     k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32)
@@ -102,7 +101,7 @@ def rpe_traj_shape(valid_rpe_kheader):
     n_k0 = valid_rpe_kheader.acq_info.number_of_samples[0, 0, 0, 0]
     n_k1 = valid_rpe_kheader.acq_info.idx.k1.shape[-1]
     n_k2 = valid_rpe_kheader.acq_info.idx.k1.shape[-2]
-    n_other = (1,1)
+    n_other = (1, 1)
     return (
         torch.Size([*n_other, n_k2, n_k1, 1]),
         torch.Size([*n_other, n_k2, n_k1, 1]),
@@ -131,11 +130,10 @@ def test_KTrajectoryRpe_uniform(valid_rpe_kheader):
         shift_between_rpe_lines=torch.tensor([0]),
     )
     trajectory2 = trajectory2_calculator(valid_rpe_kheader)
-
-
-    torch.testing.assert_close(trajectory1.kx[:,:, : n_rpe_lines // 2, :, :], trajectory2.kx[:,:, ::2, :, :])
-    torch.testing.assert_close(trajectory1.ky[:,:, : n_rpe_lines // 2, :, :], trajectory2.ky[:,:, ::2, :, :])
-    torch.testing.assert_close(trajectory1.kz[:,:, : n_rpe_lines // 2, :, :], trajectory2.kz[:,:, ::2, :, :])
+
+    torch.testing.assert_close(trajectory1.kx[:, :, : n_rpe_lines // 2, :, :], trajectory2.kx[:, :, ::2, :, :])
+    torch.testing.assert_close(trajectory1.ky[:, :, : n_rpe_lines // 2, :, :], trajectory2.ky[:, :, ::2, :, :])
+    torch.testing.assert_close(trajectory1.kz[:, :, : n_rpe_lines // 2, :, :], trajectory2.kz[:, :, ::2, :, :])
 
 
 def test_KTrajectoryRpe_shift(valid_rpe_kheader):
@@ -164,16 +162,15 @@ def valid_cartesian_kheader(monkeypatch, random_kheader):
     n_k0 = 200
     n_k1 = 20
     n_k2 = 10
-    n_other = (2,2)
+    n_other = (2, 2)
     # List of k1 and k2 indices in the shape (other, k2, k1)
     k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32)
     k2 = torch.linspace(0, n_k2 - 1, n_k2, dtype=torch.int32)
     idx_k1, idx_k2 = torch.meshgrid(k1, k2, indexing='xy')
-    idx_k1 = idx_k1.reshape(n_k2, n_k1)).repeat(*n_other, 1, 1)
-    idx_k2 = torch.reshape(n_k2, n_k1)).repeat(*n_other, 1, 1)
-
-
+    idx_k1 = idx_k1.reshape(n_k2, n_k1).repeat(*n_other, 1, 1)
+    idx_k2 = idx_k2.reshape(n_k2, n_k1).repeat(*n_other, 1, 1)
+
     # Set parameters for trajectory (AcqInfo is of shape (*other k2 k1 dim=1 or 3))
     monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0)
     monkeypatch.setattr(random_kheader.acq_info, 'center_sample', torch.zeros_like(idx_k1)[..., None] + n_k0 // 2)
     monkeypatch.setattr(random_kheader.acq_info.idx, 'k1', idx_k1)
@@ -189,10 +186,10 @@ def cartesian_traj_shape(valid_cartesian_kheader):
     """Expected shape of trajectory based on KHeader."""
-    n_k0 = valid_cartesian_kheader.acq_info.number_of_samples[0,0,0,0]
+    n_k0 = valid_cartesian_kheader.acq_info.number_of_samples[0, 0, 0, 0]
     n_k1 = valid_cartesian_kheader.acq_info.idx.k1.shape[-1]
     n_k2 = valid_cartesian_kheader.acq_info.idx.k1.shape[-2]
-    n_other = (1,1) # trajectory along other is the same
+    n_other = (1, 1)  # trajectory along other is the same
    return (torch.Size([*n_other, n_k2, 1, 1]), torch.Size([*n_other, 1, n_k1, 1]), torch.Size([*n_other, 1, 1, n_k0]))
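
The recurring change in the hunks above is prepending '...' to the einops patterns so that any extra leading batch dimensions pass through unchanged. A minimal sketch of that behaviour, using hypothetical shapes chosen only for illustration:

    import torch
    from einops import rearrange

    n_other, n_k2, n_k1, n_k0 = 3, 4, 5, 8

    # Old-style pattern: assumes exactly one combined (other k2 k1) leading axis.
    flat = torch.randn(n_other * n_k2 * n_k1, n_k0)
    sorted_traj = rearrange(flat, '(other k2 k1) k0 -> other k2 k1 k0', k1=n_k1, k2=n_k2)
    print(sorted_traj.shape)  # torch.Size([3, 4, 5, 8])

    # New-style pattern: '...' forwards any additional leading dimensions,
    # here an assumed extra batch of 2 in front of the raw-shape axes.
    batched = torch.randn(2, n_other * n_k2 * n_k1, n_k0)
    sorted_batched = rearrange(batched, '... (other k2 k1) k0 -> ... other k2 k1 k0', k1=n_k1, k2=n_k2)
    print(sorted_batched.shape)  # torch.Size([2, 3, 4, 5, 8])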