Commit

Merge branch 'main' into separate_nufft_op

ckolbPTB authored Dec 14, 2024
2 parents b2d58fa + e6568dc commit d027e4a
Showing 15 changed files with 176 additions and 23 deletions.
1 change: 1 addition & 0 deletions .github/workflows/docs.yml
@@ -155,6 +155,7 @@ jobs:
uses: actions/download-artifact@v4
with:
path: ./docs/source/_notebooks/
+ merge-multiple: true

- name: Build docs
run: |
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -14,14 +14,14 @@ repos:
- id: mixed-line-ending

- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.7.2
+ rev: v0.8.1
hooks:
- id: ruff # linter
args: [--fix]
- id: ruff-format # formatter

- repo: https://github.com/crate-ci/typos
- rev: v1.27.0
+ rev: typos-dict-v0.11.37
hooks:
- id: typos

2 changes: 1 addition & 1 deletion README.md
@@ -1,6 +1,6 @@
# MRpro

- ![Python](https://img.shields.io/badge/python-3.11%20%7C%203.12-blue)
+ ![Python](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
![Coverage Bagde](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/ckolbPTB/48e334a10caf60e6708d7c712e56d241/raw/coverage.json)

2 changes: 1 addition & 1 deletion docs/source/examples.rst
@@ -9,4 +9,4 @@ All of the notebooks can directly be run via binder or colab from the repo.
:caption: Contents:
:glob:

- _notebooks/*/*
+ _notebooks/*
15 changes: 14 additions & 1 deletion pyproject.toml
@@ -22,7 +22,20 @@ description = "MR image reconstruction and processing package specifically devel
readme = "README.md"
requires-python = ">=3.10,<3.14"
dynamic = ["version"]
keywords = ["MRI, reconstruction, processing, PyTorch"]
keywords = ["MRI",
"qMRI",
"medical imaging",
"physics-informed learning",
"model-based reconstruction",
"quantitative",
"signal models",
"machine learning",
"deep learning",
"reconstruction",
"processing",
"Pulseq",
"PyTorch",
]
authors = [
{ name = "MRpro Team", email = "info@emerpro.de" },
{ name = "Christoph Kolbitsch", email = "christoph.kolbitsch@ptb.de" },
2 changes: 1 addition & 1 deletion src/mrpro/VERSION
@@ -1 +1 @@
- 0.241126
+ 0.241210
2 changes: 1 addition & 1 deletion src/mrpro/data/AcqInfo.py
@@ -188,7 +188,7 @@ def tensor(data: np.ndarray) -> torch.Tensor:
data = data.astype(np.int32)
case np.uint32 | np.uint64:
data = data.astype(np.int64)
- # Remove any uncessary dimensions
+ # Remove any unnecessary dimensions
return torch.tensor(np.squeeze(data))

def tensor_2d(data: np.ndarray) -> torch.Tensor:
2 changes: 1 addition & 1 deletion src/mrpro/data/MoveDataMixin.py
@@ -121,7 +121,7 @@ def parse1(
) -> parsedType:
return device, dtype, non_blocking, copy, memory_format

- if args and isinstance(args[0], torch.Tensor) or 'tensor' in kwargs:
+ if (args and isinstance(args[0], torch.Tensor)) or 'tensor' in kwargs:
# overload 3 ("tensor" specifies the dtype and device)
device, dtype, non_blocking, copy, memory_format = parse3(*args, **kwargs)
elif args and isinstance(args[0], torch.dtype):
15 changes: 10 additions & 5 deletions src/mrpro/data/traj_calculators/KTrajectoryPulseq.py
@@ -54,13 +54,18 @@ def __call__(self, kheader: KHeader) -> KTrajectoryRawShape:
raise ValueError('We currently only support constant number of samples')
n_k0 = int(n_samples.item())

- def reshape_pulseq_traj(k_traj: torch.Tensor, encoding_size: int):
-     k_traj *= encoding_size / (2 * torch.max(torch.abs(k_traj)))
+ def rescale_and_reshape_traj(k_traj: torch.Tensor, encoding_size: int):
+     if encoding_size > 1 and torch.max(torch.abs(k_traj)) > 0:
+         k_traj = k_traj * encoding_size / (2 * torch.max(torch.abs(k_traj)))
+     else:
+         # We force k_traj to be 0 if encoding_size = 1. This is typically the case for kz in 2D sequences.
+         # However, it happens that seq.calculate_kspace() returns values != 0 (numerical noise) in such cases.
+         k_traj = torch.zeros_like(k_traj)
return rearrange(k_traj, '(other k0) -> other k0', k0=n_k0)

# rearrange k-space trajectory to match MRpro convention
- kx = reshape_pulseq_traj(k_traj_adc[0], kheader.encoding_matrix.x)
- ky = reshape_pulseq_traj(k_traj_adc[1], kheader.encoding_matrix.y)
- kz = reshape_pulseq_traj(k_traj_adc[2], kheader.encoding_matrix.z)
+ kx = rescale_and_reshape_traj(k_traj_adc[0], kheader.encoding_matrix.x)
+ ky = rescale_and_reshape_traj(k_traj_adc[1], kheader.encoding_matrix.y)
+ kz = rescale_and_reshape_traj(k_traj_adc[2], kheader.encoding_matrix.z)

return KTrajectoryRawShape(kz, ky, kx, self.repeat_detection_tolerance)
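
For illustration (not part of the commit): the rewritten helper rescales the Pulseq trajectory to MRpro's convention, where k-space values span -encoding_size/2 to +encoding_size/2, and forces singleton encoding directions to exact zeros instead of numerical noise. A minimal sketch of that rule on synthetic numbers; the helper name rescale and the sample values are made up for this example:

import torch

def rescale(k_traj: torch.Tensor, encoding_size: int) -> torch.Tensor:
    # mirrors the rescaling branch of rescale_and_reshape_traj above, without the reshape step
    if encoding_size > 1 and torch.max(torch.abs(k_traj)) > 0:
        return k_traj * encoding_size / (2 * torch.max(torch.abs(k_traj)))
    # encoding_size == 1 (e.g. kz of a 2D sequence): treat numerical noise as exact zeros
    return torch.zeros_like(k_traj)

kx = rescale(torch.linspace(-250.0, 250.0, 5), encoding_size=256)  # -> [-128, -64, 0, 64, 128]
kz = rescale(torch.full((5,), 1e-9), encoding_size=1)              # -> tensor of zeros
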
3 changes: 2 additions & 1 deletion src/mrpro/utils/__init__.py
@@ -6,14 +6,15 @@
from mrpro.utils.remove_repeat import remove_repeat
from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop
from mrpro.utils.split_idx import split_idx
- from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view
+ from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view, reshape_broadcasted
import mrpro.utils.unit_conversion

__all__ = [
"broadcast_right",
"fill_range_",
"reduce_view",
"remove_repeat",
"reshape_broadcasted",
"slice_profiles",
"smap",
"split_idx",
101 changes: 101 additions & 0 deletions src/mrpro/utils/reshape.py
@@ -1,6 +1,8 @@
"""Tensor reshaping utilities."""

from collections.abc import Sequence
+ from functools import lru_cache
+ from math import prod

import torch

@@ -99,3 +101,102 @@ def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torc
for d, (oldsize, stride) in enumerate(zip(x.size(), stride, strict=True))
]
return torch.as_strided(x, newsize, stride)


+ @lru_cache
+ def _reshape_idx(old_shape: tuple[int, ...], new_shape: tuple[int, ...], old_stride: tuple[int, ...]) -> list[slice]:
+     """Get reshape reduce index (cached helper function for reshape_broadcasted).
+     This function tries to group axes from new_shape and old_shape into the smallest groups that have
+     the same number of elements, starting from the right.
+     If all axes of the old shape in a group have stride 0, we can reduce them.
+     Example:
+         old_shape = (30, 2, 2, 3)
+         new_shape = (6, 5, 4, 3)
+     will result in the groups (starting from the right):
+         - old: 3       new: 3
+         - old: 2, 2    new: 4
+         - old: 30      new: 6, 5
+     Only the "old" groups are important.
+     If all axes that are grouped together in an "old" group are stride 0 (=broadcasted),
+     we can collapse them to singleton dimensions.
+     This function returns the indexer that either collapses dimensions to singleton or keeps all
+     elements, i.e. the slices in the returned list are all either slice(1) or slice(None).
+     """
+     idx = []
+     pointer_old, pointer_new = len(old_shape) - 1, len(new_shape) - 1  # start from the right
+     while pointer_old >= 0:
+         product_new, product_old = 1, 1  # the number of elements in the current "new" and "old" group
+         group: list[int] = []
+         while product_old != product_new or not group:
+             if product_old <= product_new:
+                 # increase "old" group
+                 product_old *= old_shape[pointer_old]
+                 group.append(pointer_old)
+                 pointer_old -= 1
+             else:
+                 # increase "new" group
+                 # we don't need to track the members of the new group, only the number of elements covered
+                 product_new *= new_shape[pointer_new]
+                 pointer_new -= 1
+         # we found a group. now we need to decide what to do.
+         if all(old_stride[d] == 0 for d in group):
+             # all dimensions are broadcasted
+             # -> reduce to singleton
+             idx.extend([slice(1)] * len(group))
+         else:
+             # preserve dimension
+             idx.extend([slice(None)] * len(group))
+     idx = idx[::-1]  # we worked right to left, but our index should be left to right
+     return idx


+ def reshape_broadcasted(tensor: torch.Tensor, *shape: int) -> torch.Tensor:
+     """Reshape a tensor while preserving broadcasted (stride 0) dimensions where possible.
+     Parameters
+     ----------
+     tensor
+         The input tensor to reshape.
+     shape
+         The target shape for the tensor. One of the values can be `-1` and its size will be inferred.
+     Returns
+     -------
+         A tensor reshaped to the target shape, preserving broadcasted dimensions where feasible.
+     """
+     try:
+         # if we can view the tensor directly, it will preserve broadcasting
+         return tensor.view(shape)
+     except RuntimeError:
+         # we cannot do a view, we need to do more work:
+
+         # -1 means infer size, i.e. the remaining elements of the input not already covered by the other axes.
+         negative_ones = shape.count(-1)
+         size = tensor.shape.numel()
+         if not negative_ones:
+             if prod(shape) != size:
+                 # use the same exception as pytorch
+                 raise RuntimeError(f"shape '{list(shape)}' is invalid for input of size {size}") from None
+         elif negative_ones > 1:
+             raise RuntimeError('only one dimension can be inferred') from None
+         elif negative_ones == 1:
+             # we need to figure out the size of the "-1" dimension
+             known_size = -prod(shape)  # negative, as it includes the -1
+             if size % known_size:
+                 # non-integer result: no possible size of the -1 axis exists
+                 raise RuntimeError(f"shape '{list(shape)}' is invalid for input of size {size}") from None
+             shape = tuple(size // known_size if s == -1 else s for s in shape)
+
+         # most of the broadcasted dimensions can be preserved: only dimensions that are joined with
+         # non-broadcasted dimensions cannot be preserved and must be made contiguous.
+         # all dimensions that can be preserved as broadcasted are first collapsed to singleton,
+         # such that contiguous does not create copies along these axes.
+         idx = _reshape_idx(tensor.shape, shape, tensor.stride())
+         # make contiguous only in dimensions in which broadcasting cannot be preserved
+         semicontiguous = tensor[idx].contiguous()
+         # finally, we can expand the broadcasted dimensions to the requested shape
+         semicontiguous = semicontiguous.expand(tensor.shape)
+         return semicontiguous.view(shape)
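
For context (not part of the commit, and mirroring the second test case added below): reshape_broadcasted differs from a plain torch reshape in that it keeps stride-0 (broadcast) dimensions wherever the axis grouping allows, so the expanded data is not fully materialised. A small usage sketch; the concrete numbers are made up for this example:

import torch
from mrpro.utils import reshape_broadcasted

base = torch.arange(2.0).reshape(1, 2, 1)      # 2 stored elements
expanded = base.expand(100, 2, 2)              # shape (100, 2, 2), stride (0, 1, 0), still 2 stored elements

plain = expanded.reshape(100, 4)               # plain reshape must copy: stride (4, 1), 400 stored elements
smart = reshape_broadcasted(expanded, 100, 4)  # stride (0, 1), only 4 stored elements

assert torch.equal(plain, smart)               # same values, cheaper layout
assert smart.stride() == (0, 1)
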
2 changes: 1 addition & 1 deletion src/mrpro/utils/slice_profiles.py
@@ -8,7 +8,7 @@
import torch
from torch import Tensor

- __all__ = ['SliceProfileBase', 'SliceGaussian', 'SliceSmoothedRectangular', 'SliceInterpolate']
+ __all__ = ['SliceGaussian', 'SliceInterpolate', 'SliceProfileBase', 'SliceSmoothedRectangular']


class SliceProfileBase(abc.ABC, torch.nn.Module):
2 changes: 1 addition & 1 deletion src/mrpro/utils/typing.py
@@ -28,4 +28,4 @@
NestedSequence: TypeAlias = Any
NumpyIndexerType: TypeAlias = Any

- __all__ = ['TorchIndexerType', 'NumpyIndexerType', 'NestedSequence']
+ __all__ = ['NestedSequence', 'NumpyIndexerType', 'TorchIndexerType']
12 changes: 6 additions & 6 deletions src/mrpro/utils/unit_conversion.py
@@ -6,15 +6,15 @@
import torch

__all__ = [
- 'ms_to_s',
- 's_to_ms',
- 'mm_to_m',
- 'm_to_mm',
+ 'GYROMAGNETIC_RATIO_PROTON',
'deg_to_rad',
- 'rad_to_deg',
'lamor_frequency_to_magnetic_field',
+ 'm_to_mm',
'magnetic_field_to_lamor_frequency',
- 'GYROMAGNETIC_RATIO_PROTON',
+ 'mm_to_m',
+ 'ms_to_s',
+ 'rad_to_deg',
+ 's_to_ms',
]

GYROMAGNETIC_RATIO_PROTON = 42.58 * 1e6
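
As a quick aside (not part of the commit): the constant shown above is the proton gyromagnetic ratio divided by 2*pi, in Hz/T, so the Larmor frequency at a given field strength follows by simple multiplication. A back-of-the-envelope sketch that does not assume the signatures of the helpers listed in __all__:

GYROMAGNETIC_RATIO_PROTON = 42.58 * 1e6   # Hz / T, as defined above
b0_tesla = 3.0
larmor_frequency_hz = GYROMAGNETIC_RATIO_PROTON * b0_tesla  # ~127.7 MHz at 3 T
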
34 changes: 33 additions & 1 deletion tests/utils/test_reshape.py
@@ -1,7 +1,8 @@
"""Tests for reshaping utilities."""

+ import pytest
import torch
- from mrpro.utils import broadcast_right, reduce_view, unsqueeze_left, unsqueeze_right
+ from mrpro.utils import broadcast_right, reduce_view, reshape_broadcasted, unsqueeze_left, unsqueeze_right

from tests import RandomGenerator

@@ -51,3 +52,34 @@ def test_reduce_view():
reduced_one_pos = reduce_view(tensor, 0)
assert reduced_one_pos.shape == (1, 2, 3, 4, 5, 6)
assert torch.equal(reduced_one_pos.expand_as(tensor), tensor)


+ @pytest.mark.parametrize(
+     ('shape', 'expand_shape', 'permute', 'final_shape', 'expected_stride'),
+     [
+         ((1, 2, 3, 1, 1), (1, 2, 3, 4, 5), (0, 2, 1, 3, 4), (1, 6, 2, 2, 5), (6, 1, 0, 0, 0)),
+         ((1, 2, 1), (100, 2, 2), (0, 1, 2), (100, 4), (0, 1)),
+         ((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 0, 1), (1, 2, 6, 10, 1), (0, 0, 0, 0, 0)),
+         ((1, 2, 3), (1, -1, 3), (0, 1, 2), (6,), (1,)),
+     ],
+ )
+ def test_reshape_broadcasted(shape, expand_shape, permute, final_shape, expected_stride):
+     """Test reshape_broadcasted"""
+     rng = RandomGenerator(0)
+     tensor = rng.float32_tensor(shape).expand(*expand_shape).permute(*permute)
+     reshaped = reshape_broadcasted(tensor, *final_shape)
+     expected_values = tensor.reshape(*final_shape)
+     assert reshaped.shape == expected_values.shape
+     assert reshaped.stride() == expected_stride
+     assert torch.equal(reshaped, expected_values)
+
+
+ def test_reshape_broadcasted_fail():
+     """Test reshape_broadcasted with invalid input"""
+     a = torch.ones(2)
+     with pytest.raises(RuntimeError, match='invalid'):
+         reshape_broadcasted(a, 3)
+     with pytest.raises(RuntimeError, match='invalid'):
+         reshape_broadcasted(a, -1, -3)
+     with pytest.raises(RuntimeError, match='only one dimension'):
+         reshape_broadcasted(a, -1, -1)
