Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow more than one 'other' dimension #359

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mrpro/algorithms/prewhiten_kspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def prewhiten_kspace(kdata: KData, knoise: KNoise, scale_factor: float | torch.T

# solve_triangular is numerically more stable than inverting the matrix
# but requires a single batch dimension
kdata_flat = rearrange(kdata.data, '... coil k2 k1 k0 -> (... k2 k1 k0) coil 1')
kdata_flat = rearrange(kdata.data, '... coils k2 k1 k0 -> (... k2 k1 k0) coils 1')
whitened_flat = torch.linalg.solve_triangular(cholsky, kdata_flat, upper=False)
whitened_flatother = rearrange(
whitened_flat, '(other k2 k1 k0) coil 1-> other coil k2 k1 k0', **parse_shape(kdata.data, '... k2 k1 k0')
Expand Down
9 changes: 5 additions & 4 deletions src/mrpro/data/CsmData.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING

import torch
from einops import rearrange
from typing_extensions import Self

from mrpro.data.IData import IData
Expand Down Expand Up @@ -49,8 +50,8 @@ def from_idata_walsh(
lambda img: walsh(img, smoothing_width),
chunk_size=chunk_size_otherdim,
)
csm_tensor = csm_fun(idata.data)
csm = cls(header=idata.header, data=csm_tensor)
csm_tensor = csm_fun(rearrange(idata.data, '... coils z y x->(...) coils z y x'))
csm = cls(header=idata.header, data=csm_tensor.reshape(idata.data.shape))
return csm

@classmethod
Expand All @@ -75,8 +76,8 @@ def from_idata_inati(
from mrpro.algorithms.csm.inati import inati

csm_fun = torch.vmap(lambda img: inati(img, smoothing_width), chunk_size=chunk_size_otherdim)
csm_tensor = csm_fun(idata.data)
csm = cls(header=idata.header, data=csm_tensor)
csm_tensor = csm_fun(rearrange(idata.data, '... coils z y x->(...) coils z y x'))
csm = cls(header=idata.header, data=csm_tensor.reshape(idata.data.shape))
return csm

def as_operator(self) -> SensitivityOp:
Expand Down
2 changes: 1 addition & 1 deletion src/mrpro/data/IData.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def from_tensor_and_kheader(cls, data: torch.Tensor, kheader: KHeader) -> Self:
Parameters
----------
data
image data with dimensions (broadcastable to) `(other, coils, z, y, x)`.
image data with dimensions (broadcastable to) `(*other, coils, z, y, x)`.
kheader
MR raw data header containing required meta data for the image header.
"""
Expand Down
2 changes: 1 addition & 1 deletion src/mrpro/data/KNoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def from_file(
noise_data = torch.stack([torch.as_tensor(acq.data, dtype=torch.complex64) for acq in acquisitions])

# Reshape to standard dimensions
noise_data = repeat(noise_data, '... coils k0->... coils k2 k1 k0', k1=1, k2=1)
noise_data = repeat(noise_data, '... k0 -> ... k2 k1 k0', k1=1, k2=1)

return cls(noise_data)

Expand Down
3 changes: 0 additions & 3 deletions src/mrpro/data/KTrajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
class KTrajectory(MoveDataMixin):
"""K-space trajectory.

Contains the trajectory in k-space along the three dimensions `kz`, `ky`, `kx`,
i.e. describes where in k-space each data point was acquired.

The shape of each of `kx`, `ky`, `kz` is `(*other, k2, k1, k0)`,
where `other` can span multiple dimensions.

Expand Down
6 changes: 3 additions & 3 deletions src/mrpro/data/KTrajectoryRawShape.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def sort_and_reshape(
KTrajectory with kx, ky and kz each in the shape `(other k2 k1 k0)`.
"""
# Resort and reshape
kz = rearrange(self.kz[sort_idx, ...], '(other k2 k1) k0 -> other k2 k1 k0', k1=n_k1, k2=n_k2)
ky = rearrange(self.ky[sort_idx, ...], '(other k2 k1) k0 -> other k2 k1 k0', k1=n_k1, k2=n_k2)
kx = rearrange(self.kx[sort_idx, ...], '(other k2 k1) k0 -> other k2 k1 k0', k1=n_k1, k2=n_k2)
kz = rearrange(self.kz[sort_idx, ...], '... (other k2 k1) k0 -> ... other k2 k1 k0', k1=n_k1, k2=n_k2)
ky = rearrange(self.ky[sort_idx, ...], '... (other k2 k1) k0 -> ... other k2 k1 k0', k1=n_k1, k2=n_k2)
kx = rearrange(self.kx[sort_idx, ...], '... (other k2 k1) k0 -> ... other k2 k1 k0', k1=n_k1, k2=n_k2)

return KTrajectory(kz, ky, kx, repeat_detection_tolerance=self.repeat_detection_tolerance)
2 changes: 1 addition & 1 deletion src/mrpro/data/QData.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, data: torch.Tensor, header: KHeader | IHeader | QHeader) -> N
Parameters
----------
data
quantitative image data tensor with dimensions `(other, coils, z, y, x)`
quantitative image data tensor with dimensions `(*other, coils, z, y, x)`
header
MRpro header containing required meta data for the QHeader
"""
Expand Down
3 changes: 1 addition & 2 deletions src/mrpro/data/traj_calculators/KTrajectoryRadial2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ def __call__(self, kheader: KHeader) -> KTrajectory:
# Angles of readout lines
kang = repeat(kheader.acq_info.idx.k1 * self.angle, '... k2 k1 -> ... k2 k1 k0', k0=1)

# K-space radial coordinates
kx = krad * torch.cos(kang)
ky = krad * torch.sin(kang)
kz = torch.zeros(1, 1, 1, 1)
kz = torch.zeros(kx.dim() * (1,))

return KTrajectory(kz, ky, kx)
4 changes: 2 additions & 2 deletions src/mrpro/operators/CartesianSamplingOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
if not self._needs_indexing:
return (y,)

y_kflat = rearrange(y, '... coil k2 k1 k0 -> ... coil (k2 k1 k0)')
y_kflat = rearrange(y, '... k2 k1 k0 -> ... (k2 k1 k0)')

if self._inside_encoding_matrix_idx is not None:
idx = unsqueeze_left(self._inside_encoding_matrix_idx, y_kflat.ndim - self._inside_encoding_matrix_idx.ndim)
Expand All @@ -169,7 +169,7 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
y_kflat, self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, self._fft_idx
)

# reshape to ..., other, coil, k2_enc, k1_enc, k0_enc
# reshape to ..., other, coils, k2_enc, k1_enc, k0_enc
y_reshaped = y_scattered.reshape(
*y.shape[:-3],
self._sorted_grid_shape.z,
Expand Down
2 changes: 1 addition & 1 deletion src/mrpro/operators/DensityCompensationOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ def __init__(self, dcf: DcfData | torch.Tensor) -> None:
dcf_tensor = dcf.data
else:
dcf_tensor = dcf
super().__init__(dcf_tensor, '... k2 k1 k0 ,... coil k2 k1 k0 ->... coil k2 k1 k0')
super().__init__(dcf_tensor, '... k2 k1 k0 ,... coils k2 k1 k0 ->... coils k2 k1 k0')
8 changes: 4 additions & 4 deletions src/mrpro/operators/SensitivityOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def forward(self, img: torch.Tensor) -> tuple[torch.Tensor,]:
Parameters
----------
img
image data tensor with dimensions `(other 1 z y x)`.
image data tensor with dimensions `(*other 1 z y x)`.

Returns
-------
image data tensor with dimensions `(other coils z y x)`.
image data tensor with dimensions `(*other coils z y x)`.
"""
return (self.csm_tensor * img,)

Expand All @@ -47,10 +47,10 @@ def adjoint(self, img: torch.Tensor) -> tuple[torch.Tensor,]:
Parameters
----------
img
image data tensor with dimensions `(other coils z y x)`.
image data tensor with dimensions `(*other coils z y x)`.

Returns
-------
image data tensor with dimensions `(other 1 z y x)`.
image data tensor with dimensions `(*other 1 z y x)`.
"""
return ((self.csm_tensor.conj() * img).sum(-4, keepdim=True),)
6 changes: 5 additions & 1 deletion tests/algorithms/csm/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""PyTest fixtures for the csm tests."""

import torch
from mrpro.data import IData, SpatialDimension
from mrpro.phantoms.coils import birdcage_2d


def multi_coil_image(n_coils, ph_ellipse, random_kheader):
def multi_coil_image(n_coils, ph_ellipse, random_kheader, n_other=(1,)):
"""Create multi-coil image."""
image_dimensions = SpatialDimension(z=1, y=ph_ellipse.n_y, x=ph_ellipse.n_x)

Expand All @@ -15,5 +16,8 @@ def multi_coil_image(n_coils, ph_ellipse, random_kheader):
img = ph_ellipse.phantom.image_space(image_dimensions)
# +1 to ensure that there is signal everywhere, for voxel == 0 csm cannot be determined.
img_multi_coil = (img + 1) * csm_ref

# Repeat data for multiple other dimensions
img_multi_coil = torch.tile(img_multi_coil, n_other + (1,) * 4)
idata = IData.from_tensor_and_kheader(data=img_multi_coil, kheader=random_kheader)
return (idata, csm_ref)
5 changes: 2 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,11 @@ def random_kheader_shape(request, random_acquisition, random_full_ismrmrd_header

def create_uniform_traj(nk):
"""Create a tensor of uniform points with predefined shape nk."""
kidx = torch.where(torch.tensor(nk[1:]) > 1)[0]
kidx = torch.where(torch.tensor(nk[-3:]) > 1)[0]
if len(kidx) > 1:
raise ValueError('nk is allowed to have at most one non-singleton dimension')
if len(kidx) >= 1:
# kidx+1 because we searched in nk[1:]
n_kpoints = nk[kidx + 1]
n_kpoints = nk[-3 + kidx]
k = torch.linspace(-n_kpoints // 2, n_kpoints // 2 - 1, n_kpoints, dtype=torch.float32)
views = [1 if i != n_kpoints else -1 for i in nk]
k = k.view(*views).expand(list(nk))
Expand Down
4 changes: 2 additions & 2 deletions tests/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def random_mandatory_ismrmrd_header(request) -> xsd.ismrmrdschema.ismrmrdHeader:
return xsd.ismrmrdschema.ismrmrdHeader(encoding=[encoding], experimentalConditions=experimental_conditions)


@pytest.fixture(params=({'seed': 0, 'n_other': 2, 'n_coils': 8, 'n_z': 16, 'n_y': 32, 'n_x': 64},))
@pytest.fixture(params=({'seed': 0, 'n_other': (2, 3), 'n_coils': 8, 'n_z': 16, 'n_y': 32, 'n_x': 64},))
def random_test_data(request):
seed, n_other, n_coils, n_z, n_y, n_x = (
request.param['seed'],
Expand All @@ -62,7 +62,7 @@ def random_test_data(request):
request.param['n_x'],
)
generator = RandomGenerator(seed)
test_data = generate_random_data(generator, (n_other, n_coils, n_z, n_y, n_x))
test_data = generate_random_data(generator, (*n_other, n_coils, n_z, n_y, n_x))
return test_data


Expand Down
19 changes: 18 additions & 1 deletion tests/data/test_csm_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,28 @@ def test_CsmData_smoothing_width(csm_method, ellipse_phantom, random_kheader):
assert torch.equal(csm_using_spatial_dimension.data, csm_using_int.data)


@pytest.mark.parametrize('csm_method', [CsmData.from_idata_walsh, CsmData.from_idata_inati])
def test_CsmData_cpu(csm_method, ellipse_phantom, random_kheader):
"""CsmData obtained on CPU."""
idata, csm_ref = multi_coil_image(
n_coils=4, ph_ellipse=ellipse_phantom, random_kheader=random_kheader, n_other=(2, 1)
)

# Estimate coil sensitivity maps
smoothing_width = SpatialDimension(z=1, y=5, x=5)
csm = csm_method(idata, smoothing_width)

# Phase is only relative in csm calculation, therefore only the abs values are compared.
assert relative_image_difference(torch.abs(csm.data), torch.abs(csm_ref)) <= 0.01


@pytest.mark.cuda
@pytest.mark.parametrize('csm_method', [CsmData.from_idata_walsh, CsmData.from_idata_inati])
def test_CsmData_cuda(csm_method, ellipse_phantom, random_kheader):
"""CsmData obtained on GPU in CUDA memory."""
idata, csm_ref = multi_coil_image(n_coils=4, ph_ellipse=ellipse_phantom, random_kheader=random_kheader)
idata, csm_ref = multi_coil_image(
n_coils=4, ph_ellipse=ellipse_phantom, random_kheader=random_kheader, n_other=(2, 1)
)

# Estimate coil sensitivity maps
smoothing_width = SpatialDimension(z=1, y=5, x=5)
Expand Down
11 changes: 11 additions & 0 deletions tests/data/test_dcf_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,14 @@ def test_dcf_broadcast():
trajectory = KTrajectory(kz, ky, kx)
dcf = DcfData.from_traj_voronoi(trajectory)
assert dcf.data.shape == trajectory.broadcasted_shape


def test_dcf_multi_other():
"""Test voronoi dcf calculation for multiple other dimensions."""
rng = RandomGenerator(0)
kx = rng.float32_tensor((2, 3, 1, 4, 4))
ky = rng.float32_tensor((2, 3, 1, 4, 4))
kz = torch.zeros(2, 3, 1, 1, 1)
trajectory = KTrajectory(kz, ky, kx)
dcf = DcfData.from_traj_voronoi(trajectory)
assert dcf.data.shape == trajectory.broadcasted_shape
1 change: 1 addition & 0 deletions tests/data/test_idata.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def test_IData_from_dcm_files(dcm_2d_multi_echo_times_multi_folders):
def test_IData_from_kheader_and_tensor(random_kheader, random_test_data):
"""IData from KHeader and data tensor."""
idata = IData.from_tensor_and_kheader(data=random_test_data, kheader=random_kheader)
assert idata.data.shape == random_test_data.shape
assert idata.header.te == random_kheader.te


Expand Down
14 changes: 13 additions & 1 deletion tests/data/test_knoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@
from mrpro.data import KNoise


def test_knoise_from_tensor(random_test_data):
"""Create KNoise from tensor."""
noise = KNoise(data=random_test_data)
assert noise.data.shape == random_test_data.shape


def test_knoise_from_file(ismrmrd_cart):
"""Create KNoise from file."""
knoise = KNoise.from_file(ismrmrd_cart.filename)
assert knoise is not None


def test_knoise_to_complex128(random_test_data):
"""Change dtype to complex128."""
noise = KNoise(data=random_test_data).to(dtype=torch.complex128)
Expand All @@ -13,7 +25,7 @@ def test_knoise_to_complex128(random_test_data):

@pytest.mark.cuda
def test_knoise_cuda(random_test_data):
"""Move KNois object to CUDA memory."""
"""Move KNoise object to CUDA memory."""
noise = KNoise(data=random_test_data).cuda()
assert noise.data.is_cuda

Expand Down
1 change: 1 addition & 0 deletions tests/data/test_qdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
def test_QData_from_kheader_and_tensor(random_kheader, random_test_data):
"""QData from KHeader and data tensor."""
qdata = QData(data=random_test_data, header=random_kheader)
assert qdata.data.shape == random_test_data.shape
assert qdata.header.fov == random_kheader.recon_fov


Expand Down
Loading