Merge branch 'main' into cd_pipeline
schuenke committed Oct 9, 2024
2 parents 5570fd7 + 3616027 · commit f7ef075
Showing 11 changed files with 163 additions and 65 deletions.
25 changes: 25 additions & 0 deletions .github/workflows/docs.yml
@@ -137,6 +137,9 @@ jobs:
container:
image: ghcr.io/ptb-mr/mrpro_py311:latest
options: --user runner
permissions:
pull-requests: write
contents: write
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -192,11 +195,33 @@ jobs:
if: github.event_name != 'pull_request'

- name: Save Documentation
id: save_docu
uses: actions/upload-artifact@v4
with:
name: Documentation
path: docs/build/html/

- run: echo 'Artifact url ${{ steps.save_docu.outputs.artifact-url }}'

- run: echo 'Event number ${{ github.event.number }}'

- run: echo 'Event name ${{github.event_name}}'

- name: Update PR with link to summary
if: github.event_name == 'pull_request'
uses: edumserrano/find-create-or-update-comment@v3
with:
issue-number: ${{ github.event.pull_request.number }}
body-includes: '<!-- documentation build ${{ github.event.number }} -->'
comment-author: 'github-actions[bot]'
body: |
<!-- documentation build ${{ github.event.number }} -->
### :books: Documentation
:file_folder: [Download as zip](${{ steps.save_docu.outputs.artifact-url }})
:mag: [View online](https://zimf.de/zipserve/${{ steps.save_docu.outputs.artifact-url }}/)
edit-mode: replace


concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}

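The steps added above upload the rendered HTML as a build artifact and pin its link to the pull request as a sticky comment: the hidden `<!-- documentation build N -->` marker lets reruns find and overwrite the previous comment instead of stacking a new one per push, and the `concurrency` group keyed on the PR number keeps builds for the same pull request from racing each other. A rough Python sketch of the find-or-update pattern that `find-create-or-update-comment` implements, assuming only the public GitHub REST API and a token with the `pull-requests: write` permission granted above; `OWNER` and `REPO` are placeholders:

```python
import requests

API = "https://api.github.com/repos/OWNER/REPO/issues/{pr}/comments"


def upsert_sticky_comment(pr: int, body: str, token: str) -> None:
    headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/vnd.github+json",
    }
    url = API.format(pr=pr)
    marker = f"<!-- documentation build {pr} -->"
    # Look for an existing comment that carries the hidden marker ...
    for comment in requests.get(url, headers=headers).json():
        if marker in comment.get("body", ""):
            # ... and overwrite its body, mirroring `edit-mode: replace`
            requests.patch(comment["url"], headers=headers, json={"body": body})
            return
    # First run for this PR: create the comment
    requests.post(url, headers=headers, json={"body": body})
```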
21 changes: 21 additions & 0 deletions .gitignore
@@ -6,6 +6,27 @@ __pycache__/
# C extensions
*.so

# Usual suspects regarding large file size
*.h5
*.dat
*.mp4
*.mrd
*.npy
*.npz
*.pt
*.ptc
*.pwf
*.pth
*.pdf
*.png
*.gif
*.nii
*.dcm
*.IMA
*.tar
*.zip
*.gz

# Distribution / packaging
.Python
build/
34 changes: 14 additions & 20 deletions src/mrpro/utils/Rotation.py → src/mrpro/data/Rotation.py
@@ -45,25 +45,15 @@
import re
import warnings
from collections.abc import Sequence
from typing import TYPE_CHECKING, Literal, Self, overload
from typing import Literal, Self, overload

import numpy as np
import torch
from scipy._lib._util import check_random_state
from scipy.spatial.transform import Rotation as Rotation_scipy

from mrpro.data.SpatialDimension import SpatialDimension

if TYPE_CHECKING:
from types import EllipsisType
from typing import TYPE_CHECKING, SupportsIndex, TypeAlias

from torch._C import _NestedSequence

# This matches the torch.Tensor indexer typehint
_IndexerTypeInner: TypeAlias = None | bool | int | slice | EllipsisType | torch.Tensor
_SingleIndexerType: TypeAlias = SupportsIndex | _IndexerTypeInner | _NestedSequence[_IndexerTypeInner]
IndexerType: TypeAlias = tuple[_SingleIndexerType, ...] | _SingleIndexerType
from mrpro.utils.typing import IndexerType, NestedSequence

AXIS_ORDER = 'zyx' # This can be modified
QUAT_AXIS_ORDER = AXIS_ORDER + 'w' # Do not modify
@@ -252,7 +242,11 @@ def _quaternion_to_euler(quaternion: torch.Tensor, seq: str, extrinsic: bool):


class Rotation(torch.nn.Module):
"""A pytorch implementation of scipy.spatial.transform.Rotation.
"""A container for Rotations.
A pytorch implementation of scipy.spatial.transform.Rotation.
For more information see the scipy documentation:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.html
Differences compared to scipy.spatial.transform.Rotation:
@@ -262,7 +256,7 @@ class Rotation(torch.nn.Module):
- arbitrary number of batching dimensions
"""

def __init__(self, quaternions: torch.Tensor | _NestedSequence[float], normalize: bool = True, copy: bool = True):
def __init__(self, quaternions: torch.Tensor | NestedSequence[float], normalize: bool = True, copy: bool = True):
"""Initialize a new Rotation.
Instead of calling this method, also consider the different ``from_*`` class methods to construct a Rotation.
@@ -311,7 +305,7 @@ def single(self) -> bool:
return self._single

@classmethod
def from_quat(cls, quaternions: torch.Tensor | _NestedSequence[float]) -> Self:
def from_quat(cls, quaternions: torch.Tensor | NestedSequence[float]) -> Self:
"""Initialize from quaternions.
3D rotations can be represented using unit-norm quaternions [QUAa]_.
@@ -338,7 +332,7 @@ def from_quat(cls, quaternions: torch.Tensor | _NestedSequence[float]) -> Self:
return cls(quaternions, normalize=True)

@classmethod
def from_matrix(cls, matrix: torch.Tensor | _NestedSequence[float]) -> Self:
def from_matrix(cls, matrix: torch.Tensor | NestedSequence[float]) -> Self:
"""Initialize from rotation matrix.
Rotations in 3 dimensions can be represented with 3 x 3 proper
@@ -376,7 +370,7 @@ def from_matrix(cls, matrix: torch.Tensor | _NestedSequence[float]) -> Self:
return cls(quaternions, normalize=True, copy=False)

@classmethod
def from_rotvec(cls, rotvec: torch.Tensor | _NestedSequence[float], degrees: bool = False) -> Self:
def from_rotvec(cls, rotvec: torch.Tensor | NestedSequence[float], degrees: bool = False) -> Self:
"""Initialize from rotation vector.
A rotation vector is a 3 dimensional vector which is co-directional to the
@@ -410,7 +404,7 @@ def from_rotvec(cls, rotvec: torch.Tensor | _NestedSequence[float], degrees: boo
return cls(quaternions, normalize=False, copy=False)

@classmethod
def from_euler(cls, seq: str, angles: torch.Tensor | _NestedSequence[float] | float, degrees: bool = False) -> Self:
def from_euler(cls, seq: str, angles: torch.Tensor | NestedSequence[float] | float, degrees: bool = False) -> Self:
"""Initialize from Euler angles.
Rotations in 3-D can be represented by a sequence of 3
@@ -689,7 +683,7 @@ def concatenate(cls, rotations: Sequence[Rotation]) -> Self:

def forward(
self,
vectors: _NestedSequence[float] | torch.Tensor | SpatialDimension[torch.Tensor] | SpatialDimension[float],
vectors: NestedSequence[float] | torch.Tensor | SpatialDimension[torch.Tensor] | SpatialDimension[float],
inverse: bool = False,
) -> torch.Tensor | SpatialDimension[torch.Tensor]:
"""Apply this rotation to a set of vectors.
@@ -1203,7 +1197,7 @@ def __repr__(self):

def mean(
self,
weights: torch.Tensor | _NestedSequence[float] | None = None,
weights: torch.Tensor | NestedSequence[float] | None = None,
dim: None | int | Sequence[int] = None,
keepdim: bool = False,
) -> Self:
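Since this commit moves the class into the public `mrpro.data` namespace, a minimal usage sketch may help orient readers. It assumes the scipy-compatible API the docstring promises, with the `from_euler` and `forward(vectors, inverse=...)` signatures visible in the diff; the angle values are arbitrary:

```python
import torch

from mrpro.data import Rotation  # new location as of this commit

# Batched construction works as in scipy, but accepts torch tensors
angles = torch.tensor([[0.0, 0.0, torch.pi / 2], [0.0, torch.pi, 0.0]])
rot = Rotation.from_euler('zyx', angles)

vectors = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
rotated = rot(vectors)                 # Rotation is a torch.nn.Module
restored = rot(rotated, inverse=True)  # apply the inverse rotation
```

The local `TYPE_CHECKING` aliases deleted at the top of the file now live in `mrpro.utils.typing`, which is where `IndexerType` and `NestedSequence` are imported from.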
1 change: 1 addition & 0 deletions src/mrpro/data/__init__.py
@@ -14,5 +14,6 @@
from mrpro.data.MoveDataMixin import MoveDataMixin
from mrpro.data.QData import QData
from mrpro.data.QHeader import QHeader
from mrpro.data.Rotation import Rotation
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.data.TrajectoryDescription import TrajectoryDescription
76 changes: 49 additions & 27 deletions src/mrpro/operators/GridSamplingOp.py
@@ -11,6 +11,19 @@
from mrpro.operators.LinearOperator import LinearOperator


class _AdjointGridSampleCtx(torch.autograd.function.FunctionCtx):
"""Context for Adjoint Grid Sample, used for type hinting."""

shape: Sequence[int]
interpolation_mode: int
padding_mode: int
align_corners: bool
xshape: Sequence[int]
backward_2d_or_3d: Callable
saved_tensors: Sequence[torch.Tensor]
needs_input_grad: Sequence[bool]


class AdjointGridSample(torch.autograd.Function):
"""Autograd Function for Adjoint Grid Sample.
@@ -20,14 +33,13 @@ class AdjointGridSample(torch.autograd.Function):

@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
y: torch.Tensor,
grid: torch.Tensor,
xshape: Sequence[int],
interpolation_mode: Literal['bilinear', 'nearest', 'bicubic'] = 'bilinear',
padding_mode: Literal['zeros', 'border', 'reflection'] = 'zeros',
align_corners: bool = True,
) -> torch.Tensor:
) -> tuple[torch.Tensor, tuple[int, int, Callable]]:
"""Adjoint of the linear operator x->gridsample(x,grid).
Parameters
@@ -84,18 +96,6 @@ def forward(
raise ValueError(f'xshape and y must have same number of channels, got {xshape[1]} and {y.shape[1]}.')
if len(xshape) - 2 != dim:
raise ValueError(f'len(xshape) and dim must either both be 2 or 3, got {len(xshape)} and {dim}')

# These are required in the backward.
ctx.xshape = xshape # type: ignore[attr-defined]
ctx.interpolation_mode = mode_enum # type: ignore[attr-defined]
ctx.padding_mode = padding_mode_enum # type: ignore[attr-defined]
ctx.align_corners = align_corners # type: ignore[attr-defined]
ctx.backward_2d_or_3d = backward_2d_or_3d # type: ignore[attr-defined]
if grid.requires_grad:
# only if we need to calculate the gradient for grid we need y
ctx.save_for_backward(grid, y)
else:
ctx.save_for_backward(grid)
dummy = torch.empty(1, dtype=y.dtype, device=y.device).broadcast_to(xshape)
x = backward_2d_or_3d(
y,
@@ -106,38 +106,60 @@ def forward(
align_corners=align_corners,
output_mask=[True, False],
)[0]
return x

return x, (mode_enum, padding_mode_enum, backward_2d_or_3d)

@staticmethod
def setup_context(
ctx: _AdjointGridSampleCtx,
inputs: tuple[torch.Tensor, torch.Tensor, Sequence[int], str, str, bool],
outputs: tuple[torch.Tensor, tuple[int, int, Callable]],
) -> None:
"""Save information for backward pass."""
y, grid, xshape, _, _, align_corners = inputs
_, (mode_enum, padding_mode_enum, backward_2d_or_3d) = outputs
ctx.xshape = xshape
ctx.interpolation_mode = mode_enum
ctx.padding_mode = padding_mode_enum
ctx.align_corners = align_corners
ctx.backward_2d_or_3d = backward_2d_or_3d

if ctx.needs_input_grad[1]:
# only if we need to calculate the gradient for grid we need y
ctx.save_for_backward(grid, y)
else:
ctx.save_for_backward(grid)

@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx, *grad_output: torch.Tensor
ctx: _AdjointGridSampleCtx, *grad_output: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None, None, None, None, None]:
"""Backward of the Adjoint Gridsample Operator."""
need_y_grad, need_grid_grad, *_ = ctx.needs_input_grad # type: ignore[attr-defined]
grid = ctx.saved_tensors[0] # type: ignore[attr-defined]
need_y_grad, need_grid_grad, *_ = ctx.needs_input_grad
grid = ctx.saved_tensors[0]

if need_y_grad:
# torch.grid_sampler has the same signature as the backward
# (and is used inside F.grid_sample)
grad_y = torch.grid_sampler(
grad_output[0],
grid,
ctx.interpolation_mode, # type: ignore[attr-defined]
ctx.padding_mode, # type: ignore[attr-defined]
ctx.align_corners, # type: ignore[attr-defined]
ctx.interpolation_mode,
ctx.padding_mode,
ctx.align_corners,
)
else:
grad_y = None

if need_grid_grad:
y = ctx.saved_tensors[1] # type: ignore[attr-defined]
grad_grid = ctx.backward_2d_or_3d( # type: ignore[attr-defined]
y = ctx.saved_tensors[1]
grad_grid = ctx.backward_2d_or_3d(
y,
grad_output[0],
grid,
interpolation_mode=ctx.interpolation_mode, # type: ignore[attr-defined]
padding_mode=ctx.padding_mode, # type: ignore[attr-defined]
align_corners=ctx.align_corners, # type: ignore[attr-defined]
interpolation_mode=ctx.interpolation_mode,
padding_mode=ctx.padding_mode,
align_corners=ctx.align_corners,
output_mask=[False, True],
)[1]
else:
@@ -299,7 +321,7 @@ def _adjoint_implementation(
self.interpolation_mode,
self.padding_mode,
self.align_corners,
)
)[0]
return sampled

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor]:
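The refactoring above moves all state-stashing out of `forward` and into the dedicated `setup_context` staticmethod supported by newer PyTorch versions: `forward` stays a pure computation (returning the resolved enums and backward callable as a second output so `setup_context`, which only sees `inputs` and `outputs`, can reach them), the typed `_AdjointGridSampleCtx` replaces the `# type: ignore[attr-defined]` comments, and `needs_input_grad` decides whether `y` must be saved at all. A self-contained toy sketch of the same pattern; `Scale` is invented for illustration and is not part of mrpro:

```python
import torch


class Scale(torch.autograd.Function):
    """Toy y = x * factor using the forward/setup_context split."""

    @staticmethod
    def forward(x: torch.Tensor, factor: float) -> torch.Tensor:
        # No ctx parameter anymore: forward only computes
        return x * factor

    @staticmethod
    def setup_context(ctx, inputs, output) -> None:
        # All bookkeeping for the backward pass happens here
        _, factor = inputs
        ctx.factor = factor

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One gradient per forward input; the float factor gets None
        return grad_output * ctx.factor, None


y = Scale.apply(torch.ones(3, requires_grad=True), 2.0)
y.sum().backward()
```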
35 changes: 25 additions & 10 deletions src/mrpro/operators/SliceProjectionOp.py
@@ -11,12 +11,19 @@
from numpy._typing import _NestedSequence as NestedSequence
from torch import Tensor

from mrpro.data.Rotation import Rotation
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.operators.LinearOperator import LinearOperator
from mrpro.utils.Rotation import Rotation
from mrpro.utils.slice_profiles import SliceSmoothedRectangular


class _MatrixMultiplicationCtx(torch.autograd.function.FunctionCtx):
"""Autograd context for matrix multiplication, used for type hinting."""

x_is_complex: bool
saved_tensors: tuple[Tensor]


class _MatrixMultiplication(torch.autograd.Function):
"""Helper for matrix multiplication.
@@ -27,9 +34,7 @@ class _MatrixMultiplication(torch.autograd.Function):
"""

@staticmethod
def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor, matrix: Tensor, matrix_adjoint: Tensor) -> Tensor:
ctx.save_for_backward(matrix_adjoint)
ctx.x_is_complex = x.is_complex() # type: ignore[attr-defined]
def forward(x: Tensor, matrix: Tensor, matrix_adjoint: Tensor) -> Tensor: # noqa: ARG004
if x.is_complex() == matrix.is_complex():
return matrix @ x
# required for sparse matrices to support mixed complex/real multiplication
@@ -39,9 +44,19 @@ def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor, matrix: Tensor,
return torch.complex(matrix.real @ x, matrix.imag @ x)

@staticmethod
def backward(ctx: torch.autograd.function.FunctionCtx, *grad_output: Tensor) -> tuple[Tensor, None, None]:
(matrix_adjoint,) = ctx.saved_tensors # type: ignore[attr-defined]
if ctx.x_is_complex: # type: ignore[attr-defined]
def setup_context(
ctx: _MatrixMultiplicationCtx,
inputs: tuple[Tensor, Tensor, Tensor],
outputs: tuple[Tensor], # noqa: ARG004
) -> None:
x, _, matrix_adjoint = inputs
ctx.x_is_complex = x.is_complex()
ctx.save_for_backward(matrix_adjoint)

@staticmethod
def backward(ctx: _MatrixMultiplicationCtx, *grad_output: Tensor) -> tuple[Tensor, None, None]:
(matrix_adjoint,) = ctx.saved_tensors
if ctx.x_is_complex:
if matrix_adjoint.is_complex() == grad_output[0].is_complex():
grad_x = matrix_adjoint @ grad_output[0]
elif matrix_adjoint.is_complex():
@@ -219,9 +234,9 @@ def forward(self, x: Tensor) -> tuple[Tensor]:

# For the (unusual case) of batched volumes, we will apply for each element in series
xflat = torch.atleast_2d(einops.rearrange(x, '... x y z -> (...) (x y z)'))
y = torch.stack(
[_MatrixMultiplication.apply(x, matrix, matrix_adjoint).reshape(self._range_shape) for x in xflat], -4
)
yl = [_MatrixMultiplication.apply(x, matrix, matrix_adjoint) for x in xflat]

y = torch.stack([el.reshape(self._range_shape) for el in yl], -4)
y = y.reshape(*y.shape[:-4], *x.shape[:-3], *y.shape[-3:])
return (y,)

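The real/imaginary split in `forward` exists because sparse tensors do not support mixed real/complex matrix multiplication directly: the visible branch handles a complex matrix with real input, and the collapsed branch presumably performs the symmetric split of a complex input. A toy sketch of the workaround, assuming a real sparse matrix and a complex dense input with arbitrary shapes:

```python
import torch

matrix = torch.eye(3).to_sparse()               # real sparse operator
x = torch.randn(3, 2) + 1j * torch.randn(3, 2)  # complex dense input

# `matrix @ x` would mix a real sparse with a complex dense tensor, which
# is unsupported, so multiply both parts separately and recombine:
y = torch.complex(matrix @ x.real, matrix @ x.imag)
```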
4 changes: 2 additions & 2 deletions src/mrpro/utils/__init__.py
@@ -1,7 +1,7 @@
import mrpro.utils.slice_profiles
import mrpro.utils.typing
from mrpro.utils.smap import smap
from mrpro.utils.remove_repeat import remove_repeat
from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop
from mrpro.utils.modify_acq_info import modify_acq_info
from mrpro.utils.split_idx import split_idx
from mrpro.utils.Rotation import Rotation
import mrpro.utils.slice_profiles