Skip to content

Commit

Permalink
some inital changes on other dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
guastinimara committed Jul 10, 2024
1 parent 38eeecc commit d014b72
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 48 deletions.
18 changes: 10 additions & 8 deletions src/mrpro/data/IData.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def from_single_dicom(cls, filename: str | Path) -> IData:
dataset = dcmread(filename)
idata = _dcm_pixelarray_to_tensor(dataset)[None, :]
idata = rearrange(idata, '(other coils z) y x -> other coils z y x', other=1, coils=1, z=1)

header = IHeader.from_dicom_list([dataset])
return cls(data=idata, header=header)

Expand Down Expand Up @@ -146,13 +146,15 @@ def get_unique_slice_positions(slice_pos_tag: TagType = 0x00191015):
raise ValueError('Only dicoms with the same orientation can be read in.')
# stack required due to mypy: einops rearrange list[tensor]->tensor not recognized
idata = torch.stack([_dcm_pixelarray_to_tensor(ds) for ds in dataset_list])
idata = rearrange(
idata,
'(other coils z) y x -> other coils z y x',
other=len(idata),
coils=1,
z=1,
)
# idata = rearrange(
# idata,
# '(other coils z) y x -> other coils z y x',
# other=len(idata),
# coils=1,
# z=1,
# )

idata = idata.reshape(*idata.shape[:-2],1,1,idata.shape[-2],idata.shape[-1])

header = IHeader.from_dicom_list(dataset_list)
return cls(data=idata, header=header)
Expand Down
2 changes: 1 addition & 1 deletion src/mrpro/data/KNoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,6 @@ def from_file(
noise_data = torch.stack([torch.as_tensor(acq.data, dtype=torch.complex64) for acq in acquisitions])

# Reshape to standard dimensions
noise_data = rearrange(noise_data, 'other coils (k2 k1 k0)->other coils k2 k1 k0', k1=1, k2=1)
noise_data = rearrange(noise_data, '... coils (k2 k1 k0)-> ... coils k2 k1 k0', k1=1, k2=1)

return cls(noise_data)
4 changes: 2 additions & 2 deletions src/mrpro/data/traj_calculators/KTrajectoryCartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ def __call__(self, kheader: KHeader) -> KTrajectory:
kz = (kheader.acq_info.idx.k2 - kheader.encoding_limits.k2.center).to(torch.float32)

# Bring to correct dimensions
ky = repeat(ky, 'other k2 k1-> other k2 k1 k0', k0=1)
kz = repeat(kz, 'other k2 k1-> other k2 k1 k0', k0=1)
ky = repeat(ky, '... k2 k1-> ... k2 k1 k0', k0=1)
kz = repeat(kz, '... k2 k1-> ... k2 k1 k0', k0=1)
return KTrajectory(kz, ky, kx)
2 changes: 1 addition & 1 deletion src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ def __call__(self, kheader: KHeader) -> KTrajectory:
# K-space cartesian coordinates
kx = krad * torch.cos(kang)[..., None]
ky = krad * torch.sin(kang)[..., None]
kz = torch.zeros(1, 1, 1, 1)
kz = torch.zeros(kx.dim() * (1,))

return KTrajectory(kz, ky, kx)
8 changes: 4 additions & 4 deletions src/mrpro/operators/CartesianSamplingOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
if not self._needs_indexing:
return (x,)

x_kflat = rearrange(x, '... coil k2_enc k1_enc k0_enc -> ... coil (k2_enc k1_enc k0_enc)')
x_kflat = rearrange(x, '... coils k2_enc k1_enc k0_enc -> ... coils (k2_enc k1_enc k0_enc)')
# take_along_dim does broadcast, so no need for extending here
x_indexed = torch.take_along_dim(x_kflat, self._fft_idx, dim=-1)
# reshape to (... other coil, k2, k1, k0)
# reshape to (... other coils, k2, k1, k0)
x_reshaped = x_indexed.reshape(x.shape[:-3] + self._trajectory_shape[-3:])

return (x_reshaped,)
Expand All @@ -130,7 +130,7 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
if not self._needs_indexing:
return (y,)

y_kflat = rearrange(y, '... coil k2 k1 k0 -> ... coil (k2 k1 k0)')
y_kflat = rearrange(y, '... coils k2 k1 k0 -> ... coils (k2 k1 k0)')

# scatter does not broadcast, so we need to manually broadcast the indices
broadcast_shape = torch.broadcast_shapes(self._fft_idx.shape[:-1], y_kflat.shape[:-1])
Expand All @@ -145,7 +145,7 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
device=y.device,
).scatter_(dim=-1, index=idx_expanded, src=y_kflat)

# reshape to ..., other, coil, k2_enc, k1_enc, k0_enc
# reshape to ..., other, coils, k2_enc, k1_enc, k0_enc
y_reshaped = y_scattered.reshape(
*y.shape[:-3],
self._sorted_grid_shape.z,
Expand Down
2 changes: 1 addition & 1 deletion src/mrpro/operators/DensityCompensationOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ def __init__(self, dcf: DcfData | torch.Tensor) -> None:
dcf_tensor = dcf.data
else:
dcf_tensor = dcf
super().__init__(dcf_tensor, '... k2 k1 k0 ,... coil k2 k1 k0 ->... coil k2 k1 k0')
super().__init__(dcf_tensor, '... k2 k1 k0 ,... coils k2 k1 k0 ->... coils k2 k1 k0')
73 changes: 42 additions & 31 deletions tests/data/test_traj_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ def valid_rad2d_kheader(monkeypatch, random_kheader):
n_k0 = 256
n_k1 = 10
n_k2 = 1
n_other = (1,1)

# List of k1 indices in the shape
idx_k1 = torch.arange(n_k1, dtype=torch.int32)[None, None, ...]
# List of k1 indices in the shape (other, 1, k1)
idx_k1 = torch.arange(n_k1, dtype=torch.int32).repeat(*n_other, 1, 1)

# Set parameters for radial 2D trajectory
monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0)
Expand All @@ -56,14 +57,14 @@ def valid_rad2d_kheader(monkeypatch, random_kheader):

def radial2D_traj_shape(valid_rad2d_kheader):
"""Expected shape of trajectory based on KHeader."""
n_k0 = valid_rad2d_kheader.acq_info.number_of_samples[0, 0, 0]
n_k1 = valid_rad2d_kheader.acq_info.idx.k1.shape[2]
n_k0 = valid_rad2d_kheader.acq_info.number_of_samples[0, 0, 0, 0]
n_k1 = valid_rad2d_kheader.acq_info.idx.k1.shape[-1]
n_k2 = 1
n_other = 1
n_other = (1,1)
return (
torch.Size([n_other, 1, 1, 1]),
torch.Size([n_other, n_k2, n_k1, n_k0]),
torch.Size([n_other, n_k2, n_k1, n_k0]),
torch.Size([*n_other, 1, 1, 1]),
torch.Size([*n_other, n_k2, n_k1, n_k0]),
torch.Size([*n_other, n_k2, n_k1, n_k0]),
)


Expand All @@ -84,13 +85,17 @@ def valid_rpe_kheader(monkeypatch, random_kheader):
n_k0 = 200
n_k1 = 20
n_k2 = 10

n_other= (1,1)

# List of k1 and k2 indices in the shape (other, k2, k1)
k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32)
k2 = torch.linspace(0, n_k2 - 1, n_k2, dtype=torch.int32)
idx_k1, idx_k2 = torch.meshgrid(k1, k2, indexing='xy')
idx_k1 = torch.reshape(idx_k1, (1, n_k2, n_k1))
idx_k2 = torch.reshape(idx_k2, (1, n_k2, n_k1))
idx_k1 = torch.reshape(idx_k1, (n_k2, n_k1))
idx_k1 = idx_k1.repeat(*n_other, 1, 1)
idx_k2 = torch.reshape(idx_k2, (n_k2, n_k1))
idx_k2 = idx_k2.repeat(*n_other, 1, 1)

# Set parameters for RPE trajectory
monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0)
Expand All @@ -105,14 +110,14 @@ def valid_rpe_kheader(monkeypatch, random_kheader):

def rpe_traj_shape(valid_rpe_kheader):
"""Expected shape of trajectory based on KHeader."""
n_k0 = valid_rpe_kheader.acq_info.number_of_samples[0, 0, 0]
n_k1 = valid_rpe_kheader.acq_info.idx.k1.shape[2]
n_k2 = valid_rpe_kheader.acq_info.idx.k1.shape[1]
n_other = 1
n_k0 = valid_rpe_kheader.acq_info.number_of_samples[0, 0, 0, 0]
n_k1 = valid_rpe_kheader.acq_info.idx.k1.shape[-1]
n_k2 = valid_rpe_kheader.acq_info.idx.k1.shape[-2]
n_other = (1,1)
return (
torch.Size([n_other, n_k2, n_k1, 1]),
torch.Size([n_other, n_k2, n_k1, 1]),
torch.Size([n_other, 1, 1, n_k0]),
torch.Size([*n_other, n_k2, n_k1, 1]),
torch.Size([*n_other, n_k2, n_k1, 1]),
torch.Size([*n_other, 1, 1, n_k0]),
)


Expand All @@ -128,7 +133,7 @@ def test_KTrajectoryRpe_golden(valid_rpe_kheader):

def test_KTrajectoryRpe_uniform(valid_rpe_kheader):
"""Calculate RPE trajectory with uniform angle."""
n_rpe_lines = valid_rpe_kheader.acq_info.idx.k1.shape[1]
n_rpe_lines = valid_rpe_kheader.acq_info.idx.k1.shape[-2]
trajectory1_calculator = KTrajectoryRpe(angle=torch.pi / n_rpe_lines, shift_between_rpe_lines=torch.tensor([0]))
trajectory1 = trajectory1_calculator(valid_rpe_kheader)
# Calculate trajectory with half the angular gap such that every second line should be the same as above
Expand All @@ -137,10 +142,11 @@ def test_KTrajectoryRpe_uniform(valid_rpe_kheader):
shift_between_rpe_lines=torch.tensor([0]),
)
trajectory2 = trajectory2_calculator(valid_rpe_kheader)

torch.testing.assert_close(trajectory1.kx[:, : n_rpe_lines // 2, :, :], trajectory2.kx[:, ::2, :, :])
torch.testing.assert_close(trajectory1.ky[:, : n_rpe_lines // 2, :, :], trajectory2.ky[:, ::2, :, :])
torch.testing.assert_close(trajectory1.kz[:, : n_rpe_lines // 2, :, :], trajectory2.kz[:, ::2, :, :])


torch.testing.assert_close(trajectory1.kx[:,:, : n_rpe_lines // 2, :, :], trajectory2.kx[:,:, ::2, :, :])
torch.testing.assert_close(trajectory1.ky[:,:, : n_rpe_lines // 2, :, :], trajectory2.ky[:,:, ::2, :, :])
torch.testing.assert_close(trajectory1.kz[:,:, : n_rpe_lines // 2, :, :], trajectory2.kz[:,:, ::2, :, :])


def test_KTrajectoryRpe_shift(valid_rpe_kheader):
Expand Down Expand Up @@ -169,15 +175,20 @@ def valid_cartesian_kheader(monkeypatch, random_kheader):
n_k0 = 200
n_k1 = 20
n_k2 = 10
n_other = 2
n_other = (2,2)

# List of k1 and k2 indices in the shape (other, k2, k1)
k1 = torch.linspace(0, n_k1 - 1, n_k1, dtype=torch.int32)
k2 = torch.linspace(0, n_k2 - 1, n_k2, dtype=torch.int32)
idx_k1, idx_k2 = torch.meshgrid(k1, k2, indexing='xy')
idx_k1 = repeat(torch.reshape(idx_k1, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other)
idx_k2 = repeat(torch.reshape(idx_k2, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other)

#idx_k1 = repeat(torch.reshape(idx_k1, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other)
idx_k1 = torch.reshape(idx_k1, (n_k2, n_k1))
idx_k1 = idx_k1.repeat(*n_other, 1, 1)

#idx_k2 = repeat(torch.reshape(idx_k2, (n_k2, n_k1)), 'k2 k1->other k2 k1', other=n_other)
idx_k2 = torch.reshape(idx_k2, (n_k2, n_k1))
idx_k2 = idx_k2.repeat(*n_other, 1, 1)

# Set parameters for Cartesian trajectory
monkeypatch.setattr(random_kheader.acq_info, 'number_of_samples', torch.zeros_like(idx_k1)[..., None] + n_k0)
monkeypatch.setattr(random_kheader.acq_info, 'center_sample', torch.zeros_like(idx_k1)[..., None] + n_k0 // 2)
Expand All @@ -193,11 +204,11 @@ def valid_cartesian_kheader(monkeypatch, random_kheader):

def cartesian_traj_shape(valid_cartesian_kheader):
"""Expected shape of trajectory based on KHeader."""
n_k0 = valid_cartesian_kheader.acq_info.number_of_samples[0, 0, 0]
n_k1 = valid_cartesian_kheader.acq_info.idx.k1.shape[2]
n_k2 = valid_cartesian_kheader.acq_info.idx.k1.shape[1]
n_other = 1 # trajectory along other is the same
return (torch.Size([n_other, n_k2, 1, 1]), torch.Size([n_other, 1, n_k1, 1]), torch.Size([n_other, 1, 1, n_k0]))
n_k0 = valid_cartesian_kheader.acq_info.number_of_samples[0,0,0,0]
n_k1 = valid_cartesian_kheader.acq_info.idx.k1.shape[-1]
n_k2 = valid_cartesian_kheader.acq_info.idx.k1.shape[-2]
n_other = (1,1) # trajectory along other is the same
return (torch.Size([*n_other, n_k2, 1, 1]), torch.Size([*n_other, 1, n_k1, 1]), torch.Size([*n_other, 1, 1, n_k0]))


def test_KTrajectoryCartesian(valid_cartesian_kheader):
Expand Down

0 comments on commit d014b72

Please sign in to comment.