Commit 821dbe6: Merge branch 'main' into binder-cache

lrlunin authored Dec 11, 2024
2 parents 26bd4ef + e6568dc

Showing 25 changed files with 341 additions and 42 deletions.
1 change: 1 addition & 0 deletions .github/workflows/docs.yml
```diff
@@ -155,6 +155,7 @@ jobs:
         uses: actions/download-artifact@v4
         with:
           path: ./docs/source/_notebooks/
+          merge-multiple: true
 
       - name: Build docs
         run: |
```
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
```diff
@@ -14,22 +14,22 @@ repos:
       - id: mixed-line-ending
 
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.7.2
+    rev: v0.8.1
     hooks:
       - id: ruff # linter
        args: [--fix]
      - id: ruff-format # formatter
 
  - repo: https://github.com/crate-ci/typos
-    rev: v1.27.0
+    rev: typos-dict-v0.11.37
    hooks:
      - id: typos
 
  - repo: https://github.com/fzimmermann89/check_all
-    rev: v1.0
+    rev: v1.1
    hooks:
      - id: check-init-all
-       args: [--double-quotes]
+       args: [--double-quotes, --fix]
       exclude: ^tests/
 
  - repo: https://github.com/pre-commit/mirrors-mypy
```
2 changes: 1 addition & 1 deletion README.md
```diff
@@ -1,6 +1,6 @@
 # MRpro
 
-![Python](https://img.shields.io/badge/python-3.11%20%7C%203.12-blue)
+![Python](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue)
 [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
 ![Coverage Badge](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/ckolbPTB/48e334a10caf60e6708d7c712e56d241/raw/coverage.json)
```
2 changes: 1 addition & 1 deletion docs/source/examples.rst
```diff
@@ -9,4 +9,4 @@ All of the notebooks can directly be run via binder or colab from the repo.
    :caption: Contents:
    :glob:
 
-   _notebooks/*/*
+   _notebooks/*
```
23 changes: 19 additions & 4 deletions pyproject.toml
```diff
@@ -22,7 +22,20 @@ description = "MR image reconstruction and processing package specifically devel
 readme = "README.md"
 requires-python = ">=3.10,<3.14"
 dynamic = ["version"]
-keywords = ["MRI, reconstruction, processing, PyTorch"]
+keywords = ["MRI",
+    "qMRI",
+    "medical imaging",
+    "physics-informed learning",
+    "model-based reconstruction",
+    "quantitative",
+    "signal models",
+    "machine learning",
+    "deep learning",
+    "reconstruction",
+    "processing",
+    "Pulseq",
+    "PyTorch",
+]
 authors = [
     { name = "MRpro Team", email = "info@emerpro.de" },
     { name = "Christoph Kolbitsch", email = "christoph.kolbitsch@ptb.de" },
@@ -183,12 +196,14 @@ skip-magic-trailing-comma = false
 
 [tool.typos.default]
 locale = "en-us"
+check-filename = false
 
 [tool.typos.default.extend-words]
-Reson = "Reson" # required for Proc. Intl. Soc. Mag. Reson. Med.
+Reson = "Reson"    # required for Proc. Intl. Soc. Mag. Reson. Med.
 iy = "iy"
-daa = 'daa' # required for wavelet operator
-gaus = 'gaus' # required for wavelet operator
+daa = "daa"        # required for wavelet operator
+gaus = "gaus"      # required for wavelet operator
+arange = "arange"  # torch.arange
 
 [tool.typos.files]
 extend-exclude = [
```
2 changes: 1 addition & 1 deletion src/mrpro/VERSION
```diff
@@ -1 +1 @@
-0.241112
+0.241210
```
2 changes: 1 addition & 1 deletion src/mrpro/data/AcqInfo.py
```diff
@@ -188,7 +188,7 @@ def tensor(data: np.ndarray) -> torch.Tensor:
             data = data.astype(np.int32)
         case np.uint32 | np.uint64:
             data = data.astype(np.int64)
-    # Remove any uncessary dimensions
+    # Remove any unnecessary dimensions
     return torch.tensor(np.squeeze(data))
 
 def tensor_2d(data: np.ndarray) -> torch.Tensor:
```
2 changes: 1 addition & 1 deletion src/mrpro/data/DcfData.py
```diff
@@ -56,7 +56,7 @@ def from_traj_voronoi(cls, traj: KTrajectory) -> Self:
 
         if ks_needing_voronoi:
             # Handle full dimensions needing voronoi
-            dcfs.append(smap(dcf_2d3d_voronoi, torch.stack(list(ks_needing_voronoi), -4), 4))
+            dcfs.append(smap(dcf_2d3d_voronoi, torch.stack(torch.broadcast_tensors(*ks_needing_voronoi), -4), 4))
 
         if dcfs:
             # Multiply all dcfs together
```
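Note: `torch.stack` requires all tensors to have the same shape, while the trajectory components may only be mutually broadcastable; broadcasting first avoids the failure. A minimal sketch with hypothetical shapes:

```python
import torch

# Hypothetical trajectory components: broadcastable, but not equal, shapes.
kx = torch.zeros(1, 1, 4, 8)
ky = torch.zeros(1, 4, 1, 8)

# torch.stack([kx, ky], -4) would raise a RuntimeError, since stack expects
# each tensor to be of equal size. Broadcasting first makes the shapes equal:
stacked = torch.stack(torch.broadcast_tensors(kx, ky), -4)
print(stacked.shape)  # torch.Size([1, 2, 4, 4, 8])
```
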
20 changes: 19 additions & 1 deletion src/mrpro/data/MoveDataMixin.py
```diff
@@ -121,7 +121,7 @@ def parse1(
     ) -> parsedType:
         return device, dtype, non_blocking, copy, memory_format
 
-    if args and isinstance(args[0], torch.Tensor) or 'tensor' in kwargs:
+    if (args and isinstance(args[0], torch.Tensor)) or 'tensor' in kwargs:
         # overload 3 ("tensor" specifies the dtype and device)
         device, dtype, non_blocking, copy, memory_format = parse3(*args, **kwargs)
     elif args and isinstance(args[0], torch.dtype):
@@ -239,6 +239,24 @@ def _convert(data: T) -> T:
         new.apply_(_convert, memo=memo, recurse=False)
         return new
 
+    def apply(
+        self: Self,
+        function: Callable[[Any], Any] | None = None,
+        *,
+        recurse: bool = True,
+    ) -> Self:
+        """Apply a function to all children. Returns a new object.
+
+        Parameters
+        ----------
+        function
+            The function to apply to all fields. None is interpreted as a no-op.
+        recurse
+            If True, the function will be applied to all children that are MoveDataMixin instances.
+        """
+        new = self.clone().apply_(function, recurse=recurse)
+        return new
+
     def apply_(
         self: Self,
         function: Callable[[Any], Any] | None = None,
```
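Note: a minimal usage sketch of the new out-of-place `apply`, assuming a hypothetical dataclass mixing in `MoveDataMixin` (the container and field names are illustrative):

```python
import torch
from dataclasses import dataclass
from mrpro.data import MoveDataMixin

@dataclass
class Points(MoveDataMixin):  # hypothetical example container
    x: torch.Tensor
    y: torch.Tensor

points = Points(x=torch.ones(3), y=torch.zeros(3))
# apply() clones first, so the original object is left untouched
doubled = points.apply(lambda f: f * 2 if isinstance(f, torch.Tensor) else f)
print(doubled.x)  # tensor([2., 2., 2.])
print(points.x)   # tensor([1., 1., 1.])
```

`apply_`, by contrast, modifies the object in place.
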
13 changes: 13 additions & 0 deletions src/mrpro/data/SpatialDimension.py
```diff
@@ -18,6 +18,7 @@
 VectorTypes = torch.Tensor
 ScalarTypes = int | float
 T = TypeVar('T', torch.Tensor, int, float)
+
 # Covariant types, as SpatialDimension is a Container
 # and we want, for example, SpatialDimension[int] to also be a SpatialDimension[float]
 T_co = TypeVar('T_co', torch.Tensor, int, float, covariant=True)
@@ -108,6 +109,7 @@ def from_array_zyx(
 
         return SpatialDimension(z, y, x)
 
+    # This function is mainly for type hinting and docstring
     def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self:
         """Apply a function to each z, y, x (in-place).
@@ -118,6 +120,17 @@ def apply_(self, function: Callable[[T], T] | None = None, **_) -> Self:
         """
         return super(SpatialDimension, self).apply_(function)
 
+    # This function is mainly for type hinting and docstring
+    def apply(self, function: Callable[[T], T] | None = None, **_) -> Self:
+        """Apply a function to each z, y, x (returning a new object).
+
+        Parameters
+        ----------
+        function
+            function to apply
+        """
+        return super(SpatialDimension, self).apply(function)
+
     @property
     def zyx(self) -> tuple[T_co, T_co, T_co]:
         """Return a z,y,x tuple."""
```
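Note: the out-of-place variant makes one-shot conversions on all three axes convenient. A sketch (the meter-to-millimeter scaling is just an example):

```python
from mrpro.data import SpatialDimension

fov = SpatialDimension(z=0.2, y=0.3, x=0.3)     # e.g. field of view in m
fov_mm = fov.apply(lambda value: value * 1000)  # new object, values in mm
print(fov_mm.zyx)  # (200.0, 300.0, 300.0)
print(fov.zyx)     # (0.2, 0.3, 0.3), unchanged
```
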
15 changes: 10 additions & 5 deletions src/mrpro/data/traj_calculators/KTrajectoryPulseq.py
```diff
@@ -54,13 +54,18 @@ def __call__(self, kheader: KHeader) -> KTrajectoryRawShape:
             raise ValueError('We currently only support constant number of samples')
         n_k0 = int(n_samples.item())
 
-        def reshape_pulseq_traj(k_traj: torch.Tensor, encoding_size: int):
-            k_traj *= encoding_size / (2 * torch.max(torch.abs(k_traj)))
+        def rescale_and_reshape_traj(k_traj: torch.Tensor, encoding_size: int):
+            if encoding_size > 1 and torch.max(torch.abs(k_traj)) > 0:
+                k_traj = k_traj * encoding_size / (2 * torch.max(torch.abs(k_traj)))
+            else:
+                # We force k_traj to be 0 if encoding_size = 1. This is typically the case for kz in 2D sequences.
+                # However, it happens that seq.calculate_kspace() returns values != 0 (numerical noise) in such cases.
+                k_traj = torch.zeros_like(k_traj)
             return rearrange(k_traj, '(other k0) -> other k0', k0=n_k0)
 
         # rearrange k-space trajectory to match MRpro convention
-        kx = reshape_pulseq_traj(k_traj_adc[0], kheader.encoding_matrix.x)
-        ky = reshape_pulseq_traj(k_traj_adc[1], kheader.encoding_matrix.y)
-        kz = reshape_pulseq_traj(k_traj_adc[2], kheader.encoding_matrix.z)
+        kx = rescale_and_reshape_traj(k_traj_adc[0], kheader.encoding_matrix.x)
+        ky = rescale_and_reshape_traj(k_traj_adc[1], kheader.encoding_matrix.y)
+        kz = rescale_and_reshape_traj(k_traj_adc[2], kheader.encoding_matrix.z)
 
         return KTrajectoryRawShape(kz, ky, kx, self.repeat_detection_tolerance)
```
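Note: the rescaling maps Pulseq's physical k-space values onto MRpro's convention, where the largest |k| along an axis corresponds to half the encoding matrix size; singleton axes are zeroed because residual values there are numerical noise. A standalone sketch of the scaling rule only:

```python
import torch

def rescale(k_traj: torch.Tensor, encoding_size: int) -> torch.Tensor:
    """Sketch of the scaling rule above, without the reshape step."""
    if encoding_size > 1 and torch.max(torch.abs(k_traj)) > 0:
        # the largest |k| is mapped to encoding_size / 2
        return k_traj * encoding_size / (2 * torch.max(torch.abs(k_traj)))
    # singleton axis (e.g. kz in a 2D sequence): treat residual values as noise
    return torch.zeros_like(k_traj)

k = torch.tensor([-250.0, 0.0, 250.0])   # hypothetical k-values in 1/m
print(rescale(k, 256))                   # tensor([-128., 0., 128.])
print(rescale(torch.tensor([1e-9]), 1))  # tensor([0.])
```
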
6 changes: 3 additions & 3 deletions src/mrpro/operators/FourierOp.py
```diff
@@ -108,17 +108,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
             # Broadcast shapes not always needed but also does not hurt
             omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega]
             self.register_buffer('_omega', torch.stack(omega, dim=-4))  # use the 'coil' dim for the direction
-
+            numpoints = [min(img_size, nufft_numpoints) for img_size in self._nufft_im_size]
             self._fwd_nufft_op: KbNufftAdjoint | None = KbNufft(
                 im_size=self._nufft_im_size,
                 grid_size=grid_size,
-                numpoints=nufft_numpoints,
+                numpoints=numpoints,
                 kbwidth=nufft_kbwidth,
             )
             self._adj_nufft_op: KbNufftAdjoint | None = KbNufftAdjoint(
                 im_size=self._nufft_im_size,
                 grid_size=grid_size,
-                numpoints=nufft_numpoints,
+                numpoints=numpoints,
                 kbwidth=nufft_kbwidth,
             )
         else:
```
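Note: the new `numpoints` clamps the interpolation kernel width so it never exceeds the image size along any dimension, which matters for small matrix sizes (e.g. few slices). A quick illustration of the clamping:

```python
nufft_numpoints = 6            # default interpolation kernel width
nufft_im_size = (3, 256, 256)  # hypothetical image size with a small first dimension

numpoints = [min(img_size, nufft_numpoints) for img_size in nufft_im_size]
print(numpoints)  # [3, 6, 6]: the kernel never exceeds the image size
```
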
7 changes: 1 addition & 6 deletions src/mrpro/operators/LinearOperator.py
```diff
@@ -415,9 +415,4 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
     @property
     def H(self) -> LinearOperator:  # noqa: N802
         """Adjoint of adjoint operator, i.e. original LinearOperator."""
-        return self.operator
-
-    @property
-    def gram(self) -> LinearOperator:
-        """Gram operator."""
-        return self._operator.gram.H
+        return self._operator
```
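Note: with the corrected attribute name, taking the adjoint twice hands back the stored original operator, and the `gram` override is removed. A sketch of the round trip, using `FastFourierOp` as an arbitrary concrete operator (an illustrative choice, not part of this diff):

```python
from mrpro.operators import FastFourierOp

op = FastFourierOp(dim=(-1,))
assert op.H.H is op  # the adjoint of the adjoint is the original operator
```
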
7 changes: 5 additions & 2 deletions src/mrpro/utils/__init__.py
```diff
@@ -1,17 +1,20 @@
 import mrpro.utils.slice_profiles
 import mrpro.utils.typing
+import mrpro.utils.unit_conversion
+from mrpro.utils.fill_range import fill_range_
 from mrpro.utils.smap import smap
 from mrpro.utils.remove_repeat import remove_repeat
 from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop
 from mrpro.utils.split_idx import split_idx
-from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view
-import mrpro.utils.unit_conversion
+from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view, reshape_broadcasted
 
 __all__ = [
     "broadcast_right",
+    "fill_range_",
     "reduce_view",
     "remove_repeat",
+    "reshape_broadcasted",
     "slice_profiles",
     "smap",
     "split_idx",
@@ -20,4 +23,4 @@
     "unsqueeze_left",
     "unsqueeze_right",
     "zero_pad_or_crop"
 ]
```
24 changes: 24 additions & 0 deletions src/mrpro/utils/fill_range.py
```diff
@@ -0,0 +1,24 @@
+"""Fill tensor in-place along a specified dimension with increasing integers."""
+
+import torch
+
+
+def fill_range_(tensor: torch.Tensor, dim: int) -> None:
+    """
+    Fill tensor in-place along a specified dimension with increasing integers.
+
+    Parameters
+    ----------
+    tensor
+        The tensor to be modified in-place.
+    dim
+        The dimension along which to fill with increasing values.
+    """
+    if not -tensor.ndim <= dim < tensor.ndim:
+        raise IndexError(f'Dimension {dim} is out of range for tensor with {tensor.ndim} dimensions.')
+
+    dim = dim % tensor.ndim
+    shape = [s if d == dim else 1 for d, s in enumerate(tensor.shape)]
+    values = torch.arange(tensor.size(dim), device=tensor.device).reshape(shape)
+    tensor[:] = values.expand_as(tensor)
```
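Note: a minimal usage example of the new helper:

```python
import torch
from mrpro.utils import fill_range_

tensor = torch.zeros(2, 4)
fill_range_(tensor, dim=1)  # in-place, returns None
print(tensor)
# tensor([[0., 1., 2., 3.],
#         [0., 1., 2., 3.]])
```
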
101 changes: 101 additions & 0 deletions src/mrpro/utils/reshape.py
```diff
@@ -1,6 +1,8 @@
 """Tensor reshaping utilities."""
 
 from collections.abc import Sequence
+from functools import lru_cache
+from math import prod
 
 import torch
 
@@ -99,3 +101,102 @@ def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torch.Tensor:
         for d, (oldsize, stride) in enumerate(zip(x.size(), stride, strict=True))
     ]
     return torch.as_strided(x, newsize, stride)
+
+
+@lru_cache
+def _reshape_idx(old_shape: tuple[int, ...], new_shape: tuple[int, ...], old_stride: tuple[int, ...]) -> list[slice]:
+    """Get reshape reduce index (cached helper function for reshape_broadcasted).
+
+    This function tries to group axes from new_shape and old_shape into the smallest groups that have
+    the same number of elements, starting from the right.
+    If all axes of the old shape in a group are stride-0 dimensions, we can reduce them.
+
+    Example:
+        old_shape = (30, 2, 2, 3)
+        new_shape = (6, 5, 4, 3)
+    will result in the groups (starting from the right):
+        - old: 3     new: 3
+        - old: 2, 2  new: 4
+        - old: 30    new: 6, 5
+    Only the "old" groups are important.
+    If all axes that are grouped together in an "old" group are stride 0 (=broadcasted),
+    we can collapse them to singleton dimensions.
+    This function returns the indexer that either collapses dimensions to singleton or keeps all
+    elements, i.e. the slices in the returned list are all either slice(1) or slice(None).
+    """
+    idx = []
+    pointer_old, pointer_new = len(old_shape) - 1, len(new_shape) - 1  # start from the right
+    while pointer_old >= 0:
+        product_new, product_old = 1, 1  # the number of elements in the current "new" and "old" group
+        group: list[int] = []
+        while product_old != product_new or not group:
+            if product_old <= product_new:
+                # increase "old" group
+                product_old *= old_shape[pointer_old]
+                group.append(pointer_old)
+                pointer_old -= 1
+            else:
+                # increase "new" group
+                # we don't need to track the new group, only the number of elements covered
+                product_new *= new_shape[pointer_new]
+                pointer_new -= 1
+        # we found a group. now we need to decide what to do.
+        if all(old_stride[d] == 0 for d in group):
+            # all dimensions are broadcasted
+            # -> reduce to singleton
+            idx.extend([slice(1)] * len(group))
+        else:
+            # preserve dimension
+            idx.extend([slice(None)] * len(group))
+    idx = idx[::-1]  # we worked right to left, but our index should be left to right
+    return idx
+
+
+def reshape_broadcasted(tensor: torch.Tensor, *shape: int) -> torch.Tensor:
+    """Reshape a tensor while preserving broadcasted (stride 0) dimensions where possible.
+
+    Parameters
+    ----------
+    tensor
+        The input tensor to reshape.
+    shape
+        The target shape for the tensor. One of the values can be `-1` and its size will be inferred.
+
+    Returns
+    -------
+        A tensor reshaped to the target shape, preserving broadcasted dimensions where feasible.
+    """
+    try:
+        # if we can view the tensor directly, it will preserve broadcasting
+        return tensor.view(shape)
+    except RuntimeError:
+        # we cannot do a view, we need to do more work:
+
+        # -1 means infer size, i.e. the remaining elements of the input not already covered by the other axes.
+        negative_ones = shape.count(-1)
+        size = tensor.shape.numel()
+        if not negative_ones:
+            if prod(shape) != size:
+                # use same exception as pytorch
+                raise RuntimeError(f"shape '{list(shape)}' is invalid for input of size {size}") from None
+        elif negative_ones > 1:
+            raise RuntimeError('only one dimension can be inferred') from None
+        elif negative_ones == 1:
+            # we need to figure out the size of the "-1" dimension
+            known_size = -prod(shape)  # negative, as it includes the -1
+            if size % known_size:
+                # non-integer result. no possible size of the -1 axis exists.
+                raise RuntimeError(f"shape '{list(shape)}' is invalid for input of size {size}") from None
+            shape = tuple(size // known_size if s == -1 else s for s in shape)
+
+        # most of the broadcasted dimensions can be preserved: only dimensions that are joined with non
+        # broadcasted dimensions can not be preserved and must be made contiguous.
+        # all dimensions that can be preserved as broadcasted are first collapsed to singleton,
+        # such that contiguous does not create copies along these axes.
+        idx = _reshape_idx(tensor.shape, shape, tensor.stride())
+        # make contiguous only in dimensions in which broadcasting cannot be preserved
+        semicontiguous = tensor[idx].contiguous()
+        # finally, we can expand the broadcasted dimensions to the requested shape
+        semicontiguous = semicontiguous.expand(tensor.shape)
+        return semicontiguous.view(shape)
```
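Note: a short demonstration of the intended benefit, assuming an expanded (stride-0) input: `reshape_broadcasted` keeps the broadcast axis unmaterialized where a plain `reshape` would copy:

```python
import torch
from mrpro.utils import reshape_broadcasted

x = torch.randn(3, 1, 4).expand(3, 500, 4)  # broadcasted middle axis (stride 0)

y = reshape_broadcasted(x, 3, 500, 2, 2)
print(y.shape)    # torch.Size([3, 500, 2, 2])
print(y.stride()) # (4, 0, 2, 1): axis 1 is still broadcasted, no copy made
print(x.reshape(3, 500, 2, 2).stride())  # (2000, 4, 2, 1): materialized copy
```
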