From d014b72b075a6e1b4681a1a9ba51486425b1fe4c Mon Sep 17 00:00:00 2001 From: guastinimara Date: Wed, 10 Jul 2024 12:57:07 +0200 Subject: [PATCH] some inital changes on other dimension --- src/mrpro/data/IData.py | 18 +++-- src/mrpro/data/KNoise.py | 2 +- .../traj_calculators/KTrajectoryCartesian.py | 4 +- .../traj_calculators/KTrajectoryRadial2D.py | 2 +- src/mrpro/operators/CartesianSamplingOp.py | 8 +- src/mrpro/operators/DensityCompensationOp.py | 2 +- tests/data/test_traj_calculators.py | 73 +++++++++++-------- 7 files changed, 61 insertions(+), 48 deletions(-) diff --git a/src/mrpro/data/IData.py b/src/mrpro/data/IData.py index 5d9cd6af..f2d16600 100644 --- a/src/mrpro/data/IData.py +++ b/src/mrpro/data/IData.py @@ -110,7 +110,7 @@ def from_single_dicom(cls, filename: str | Path) -> IData: dataset = dcmread(filename) idata = _dcm_pixelarray_to_tensor(dataset)[None, :] idata = rearrange(idata, '(other coils z) y x -> other coils z y x', other=1, coils=1, z=1) - + header = IHeader.from_dicom_list([dataset]) return cls(data=idata, header=header) @@ -146,13 +146,15 @@ def get_unique_slice_positions(slice_pos_tag: TagType = 0x00191015): raise ValueError('Only dicoms with the same orientation can be read in.') # stack required due to mypy: einops rearrange list[tensor]->tensor not recognized idata = torch.stack([_dcm_pixelarray_to_tensor(ds) for ds in dataset_list]) - idata = rearrange( - idata, - '(other coils z) y x -> other coils z y x', - other=len(idata), - coils=1, - z=1, - ) + # idata = rearrange( + # idata, + # '(other coils z) y x -> other coils z y x', + # other=len(idata), + # coils=1, + # z=1, + # ) + + idata = idata.reshape(*idata.shape[:-2],1,1,idata.shape[-2],idata.shape[-1]) header = IHeader.from_dicom_list(dataset_list) return cls(data=idata, header=header) diff --git a/src/mrpro/data/KNoise.py b/src/mrpro/data/KNoise.py index 3a91f2cd..dd8cda6e 100644 --- a/src/mrpro/data/KNoise.py +++ b/src/mrpro/data/KNoise.py @@ -67,6 +67,6 @@ def from_file( noise_data = torch.stack([torch.as_tensor(acq.data, dtype=torch.complex64) for acq in acquisitions]) # Reshape to standard dimensions - noise_data = rearrange(noise_data, 'other coils (k2 k1 k0)->other coils k2 k1 k0', k1=1, k2=1) + noise_data = rearrange(noise_data, '... coils (k2 k1 k0)-> ... coils k2 k1 k0', k1=1, k2=1) return cls(noise_data) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py b/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py index 463eac66..96ad39c3 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py @@ -45,6 +45,6 @@ def __call__(self, kheader: KHeader) -> KTrajectory: kz = (kheader.acq_info.idx.k2 - kheader.encoding_limits.k2.center).to(torch.float32) # Bring to correct dimensions - ky = repeat(ky, 'other k2 k1-> other k2 k1 k0', k0=1) - kz = repeat(kz, 'other k2 k1-> other k2 k1 k0', k0=1) + ky = repeat(ky, '... k2 k1-> ... k2 k1 k0', k0=1) + kz = repeat(kz, '... k2 k1-> ... k2 k1 k0', k0=1) return KTrajectory(kz, ky, kx) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py b/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py index 13b832d8..6c9d5364 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py @@ -56,6 +56,6 @@ def __call__(self, kheader: KHeader) -> KTrajectory: # K-space cartesian coordinates kx = krad * torch.cos(kang)[..., None] ky = krad * torch.sin(kang)[..., None] - kz = torch.zeros(1, 1, 1, 1) + kz = torch.zeros(kx.dim() * (1,)) return KTrajectory(kz, ky, kx) diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py index 239f8f88..7882ca9d 100644 --- a/src/mrpro/operators/CartesianSamplingOp.py +++ b/src/mrpro/operators/CartesianSamplingOp.py @@ -104,10 +104,10 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: if not self._needs_indexing: return (x,) - x_kflat = rearrange(x, '... coil k2_enc k1_enc k0_enc -> ... coil (k2_enc k1_enc k0_enc)') + x_kflat = rearrange(x, '... coils k2_enc k1_enc k0_enc -> ... coils (k2_enc k1_enc k0_enc)') # take_along_dim does broadcast, so no need for extending here x_indexed = torch.take_along_dim(x_kflat, self._fft_idx, dim=-1) - # reshape to (... other coil, k2, k1, k0) + # reshape to (... other coils, k2, k1, k0) x_reshaped = x_indexed.reshape(x.shape[:-3] + self._trajectory_shape[-3:]) return (x_reshaped,) @@ -130,7 +130,7 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: if not self._needs_indexing: return (y,) - y_kflat = rearrange(y, '... coil k2 k1 k0 -> ... coil (k2 k1 k0)') + y_kflat = rearrange(y, '... coils k2 k1 k0 -> ... coils (k2 k1 k0)') # scatter does not broadcast, so we need to manually broadcast the indices broadcast_shape = torch.broadcast_shapes(self._fft_idx.shape[:-1], y_kflat.shape[:-1]) @@ -145,7 +145,7 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]: device=y.device, ).scatter_(dim=-1, index=idx_expanded, src=y_kflat) - # reshape to ..., other, coil, k2_enc, k1_enc, k0_enc + # reshape to ..., other, coils, k2_enc, k1_enc, k0_enc y_reshaped = y_scattered.reshape( *y.shape[:-3], self._sorted_grid_shape.z, diff --git a/src/mrpro/operators/DensityCompensationOp.py b/src/mrpro/operators/DensityCompensationOp.py index ee681dab..5cd9487e 100644 --- a/src/mrpro/operators/DensityCompensationOp.py +++ b/src/mrpro/operators/DensityCompensationOp.py @@ -40,4 +40,4 @@ def __init__(self, dcf: DcfData | torch.Tensor) -> None: dcf_tensor = dcf.data else: dcf_tensor = dcf - super().__init__(dcf_tensor, '... k2 k1 k0 ,... coil k2 k1 k0 ->... coil k2 k1 k0') + super().__init__(dcf_tensor, '... k2 k1 k0 ,... coils k2 k1 k0 ->... coils k2 k1 k0') diff --git a/tests/data/test_traj_calculators.py b/tests/data/test_traj_calculators.py index 80d32403..7892734d 100644 --- a/tests/data/test_traj_calculators.py +++ b/tests/data/test_traj_calculators.py @@ -36,9 +36,10 @@ def valid_rad2d_kheader(monkeypatch, random_kheader): n_k0 = 256 n_k1 = 10 n_k2 = 1 + n_other = (1,1) - # List of k1 indices in the shape - idx_k1 = torch.arange(n_k1, dtype=torch.int32)[None, None, ...] + # List of k1 indices in the shape (other, 1, k1) + idx_k1 = torch.arange(n_k1, dtype=torch.int32).repeat(*n_other, 1, 1) # Set parameters for radial 2D trajectory monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0) @@ -56,14 +57,14 @@ def valid_rad2d_kheader(monkeypatch, random_kheader): def radial2D_traj_shape(valid_rad2d_kheader): """Expected shape of trajectory based on KHeader.""" - n_k0 = valid_rad2d_kheader.acq_info.number_of_samples[0, 0, 0] - n_k1 = valid_rad2d_kheader.acq_info.idx.k1.shape[2] + n_k0 = valid_rad2d_kheader.acq_info.number_of_samples[0, 0, 0, 0] + n_k1 = valid_rad2d_kheader.acq_info.idx.k1.shape[-1] n_k2 = 1 - n_other = 1 + n_other = (1,1) return ( - torch.Size([n_other, 1, 1, 1]), - torch.Size([n_other, n_k2, n_k1, n_k0]), - torch.Size([n_other, n_k2, n_k1, n_k0]), + torch.Size([*n_other, 1, 1, 1]), + torch.Size([*n_other, n_k2, n_k1, n_k0]), + torch.Size([*n_other, n_k2, n_k1, n_k0]), ) @@ -84,13 +85,17 @@ def valid_rpe_kheader(monkeypatch, random_kheader): n_k0 = 200 n_k1 = 20 n_k2 = 10 + + n_other= (1,1) # List of k1 and k2 indices in the shape (other, k2, k1) k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32) k2 = torch.linspace(0, n_k2 - 1, n_k2, dtype=torch.int32) idx_k1, idx_k2 = torch.meshgrid(k1, k2, indexing='xy') - idx_k1 = torch.reshape(idx_k1, (1, n_k2, n_k1)) - idx_k2 = torch.reshape(idx_k2, (1, n_k2, n_k1)) + idx_k1 = torch.reshape(idx_k1, (n_k2, n_k1)) + idx_k1 = idx_k1.repeat(*n_other, 1, 1) + idx_k2 = torch.reshape(idx_k2, (n_k2, n_k1)) + idx_k2 = idx_k2.repeat(*n_other, 1, 1) # Set parameters for RPE trajectory monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0) @@ -105,14 +110,14 @@ def valid_rpe_kheader(monkeypatch, random_kheader): def rpe_traj_shape(valid_rpe_kheader): """Expected shape of trajectory based on KHeader.""" - n_k0 = valid_rpe_kheader.acq_info.number_of_samples[0, 0, 0] - n_k1 = valid_rpe_kheader.acq_info.idx.k1.shape[2] - n_k2 = valid_rpe_kheader.acq_info.idx.k1.shape[1] - n_other = 1 + n_k0 = valid_rpe_kheader.acq_info.number_of_samples[0, 0, 0, 0] + n_k1 = valid_rpe_kheader.acq_info.idx.k1.shape[-1] + n_k2 = valid_rpe_kheader.acq_info.idx.k1.shape[-2] + n_other = (1,1) return ( - torch.Size([n_other, n_k2, n_k1, 1]), - torch.Size([n_other, n_k2, n_k1, 1]), - torch.Size([n_other, 1, 1, n_k0]), + torch.Size([*n_other, n_k2, n_k1, 1]), + torch.Size([*n_other, n_k2, n_k1, 1]), + torch.Size([*n_other, 1, 1, n_k0]), ) @@ -128,7 +133,7 @@ def test_KTrajectoryRpe_golden(valid_rpe_kheader): def test_KTrajectoryRpe_uniform(valid_rpe_kheader): """Calculate RPE trajectory with uniform angle.""" - n_rpe_lines = valid_rpe_kheader.acq_info.idx.k1.shape[1] + n_rpe_lines = valid_rpe_kheader.acq_info.idx.k1.shape[-2] trajectory1_calculator = KTrajectoryRpe(angle=torch.pi / n_rpe_lines, shift_between_rpe_lines=torch.tensor([0])) trajectory1 = trajectory1_calculator(valid_rpe_kheader) # Calculate trajectory with half the angular gap such that every second line should be the same as above @@ -137,10 +142,11 @@ def test_KTrajectoryRpe_uniform(valid_rpe_kheader): shift_between_rpe_lines=torch.tensor([0]), ) trajectory2 = trajectory2_calculator(valid_rpe_kheader) - - torch.testing.assert_close(trajectory1.kx[:, : n_rpe_lines // 2, :, :], trajectory2.kx[:, ::2, :, :]) - torch.testing.assert_close(trajectory1.ky[:, : n_rpe_lines // 2, :, :], trajectory2.ky[:, ::2, :, :]) - torch.testing.assert_close(trajectory1.kz[:, : n_rpe_lines // 2, :, :], trajectory2.kz[:, ::2, :, :]) + + + torch.testing.assert_close(trajectory1.kx[:,:, : n_rpe_lines // 2, :, :], trajectory2.kx[:,:, ::2, :, :]) + torch.testing.assert_close(trajectory1.ky[:,:, : n_rpe_lines // 2, :, :], trajectory2.ky[:,:, ::2, :, :]) + torch.testing.assert_close(trajectory1.kz[:,:, : n_rpe_lines // 2, :, :], trajectory2.kz[:,:, ::2, :, :]) def test_KTrajectoryRpe_shift(valid_rpe_kheader): @@ -169,15 +175,20 @@ def valid_cartesian_kheader(monkeypatch, random_kheader): n_k0 = 200 n_k1 = 20 n_k2 = 10 - n_other = 2 + n_other = (2,2) # List of k1 and k2 indices in the shape (other, k2, k1) k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32) k2 = torch.linspace(0, n_k2 - 1, n_k2, dtype=torch.int32) idx_k1, idx_k2 = torch.meshgrid(k1, k2, indexing='xy') - idx_k1 = repeat(torch.reshape(idx_k1, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other) - idx_k2 = repeat(torch.reshape(idx_k2, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other) - + #idx_k1 = repeat(torch.reshape(idx_k1, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other) + idx_k1 = torch.reshape(idx_k1, (n_k2, n_k1)) + idx_k1 = idx_k1.repeat(*n_other, 1, 1) + + #idx_k2 = repeat(torch.reshape(idx_k2, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other) + idx_k2 = torch.reshape(idx_k2, (n_k2, n_k1)) + idx_k2 = idx_k2.repeat(*n_other, 1, 1) + # Set parameters for Cartesian trajectory monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0) monkeypatch.setattr(random_kheader.acq_info, 'center_sample', torch.zeros_like(idx_k1)[..., None] + n_k0 // 2) @@ -193,11 +204,11 @@ def valid_cartesian_kheader(monkeypatch, random_kheader): def cartesian_traj_shape(valid_cartesian_kheader): """Expected shape of trajectory based on KHeader.""" - n_k0 = valid_cartesian_kheader.acq_info.number_of_samples[0, 0, 0] - n_k1 = valid_cartesian_kheader.acq_info.idx.k1.shape[2] - n_k2 = valid_cartesian_kheader.acq_info.idx.k1.shape[1] - n_other = 1 # trajectory along other is the same - return (torch.Size([n_other, n_k2, 1, 1]), torch.Size([n_other, 1, n_k1, 1]), torch.Size([n_other, 1, 1, n_k0])) + n_k0 = valid_cartesian_kheader.acq_info.number_of_samples[0,0,0,0] + n_k1 = valid_cartesian_kheader.acq_info.idx.k1.shape[-1] + n_k2 = valid_cartesian_kheader.acq_info.idx.k1.shape[-2] + n_other = (1,1) # trajectory along other is the same + return (torch.Size([*n_other, n_k2, 1, 1]), torch.Size([*n_other, 1, n_k1, 1]), torch.Size([*n_other, 1, 1, n_k0])) def test_KTrajectoryCartesian(valid_cartesian_kheader):