From 91e8a1b6c1076cfd93af17265372a52696973be2 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 30 Sep 2024 15:41:46 +0200 Subject: [PATCH 1/6] Move SliceProjectionOp to new autograd.Function syntax (#409) --- src/mrpro/operators/SliceProjectionOp.py | 33 +++++++++++++++++------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/src/mrpro/operators/SliceProjectionOp.py b/src/mrpro/operators/SliceProjectionOp.py index e58239c8..1233e86a 100644 --- a/src/mrpro/operators/SliceProjectionOp.py +++ b/src/mrpro/operators/SliceProjectionOp.py @@ -17,6 +17,13 @@ from mrpro.utils.slice_profiles import SliceSmoothedRectangular +class _MatrixMultiplicationCtx(torch.autograd.function.FunctionCtx): + """Autograd context for matrix multiplication, used for type hinting.""" + + x_is_complex: bool + saved_tensors: tuple[Tensor] + + class _MatrixMultiplication(torch.autograd.Function): """Helper for matrix multiplication. @@ -27,9 +34,7 @@ class _MatrixMultiplication(torch.autograd.Function): """ @staticmethod - def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor, matrix: Tensor, matrix_adjoint: Tensor) -> Tensor: - ctx.save_for_backward(matrix_adjoint) - ctx.x_is_complex = x.is_complex() # type: ignore[attr-defined] + def forward(x: Tensor, matrix: Tensor, matrix_adjoint: Tensor) -> Tensor: # noqa: ARG004 if x.is_complex() == matrix.is_complex(): return matrix @ x # required for sparse matrices to support mixed complex/real multiplication @@ -39,9 +44,19 @@ def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor, matrix: Tensor, return torch.complex(matrix.real @ x, matrix.imag @ x) @staticmethod - def backward(ctx: torch.autograd.function.FunctionCtx, *grad_output: Tensor) -> tuple[Tensor, None, None]: - (matrix_adjoint,) = ctx.saved_tensors # type: ignore[attr-defined] - if ctx.x_is_complex: # type: ignore[attr-defined] + def setup_context( + ctx: _MatrixMultiplicationCtx, + inputs: tuple[Tensor, Tensor, Tensor], + outputs: tuple[Tensor], # noqa: ARG004 + ) -> None: + x, _, matrix_adjoint = inputs + ctx.x_is_complex = x.is_complex() + ctx.save_for_backward(matrix_adjoint) + + @staticmethod + def backward(ctx: _MatrixMultiplicationCtx, *grad_output: Tensor) -> tuple[Tensor, None, None]: + (matrix_adjoint,) = ctx.saved_tensors + if ctx.x_is_complex: if matrix_adjoint.is_complex() == grad_output[0].is_complex(): grad_x = matrix_adjoint @ grad_output[0] elif matrix_adjoint.is_complex(): @@ -219,9 +234,9 @@ def forward(self, x: Tensor) -> tuple[Tensor]: # For the (unusual case) of batched volumes, we will apply for each element in series xflat = torch.atleast_2d(einops.rearrange(x, '... x y z -> (...) 
(x y z)')) - y = torch.stack( - [_MatrixMultiplication.apply(x, matrix, matrix_adjoint).reshape(self._range_shape) for x in xflat], -4 - ) + yl = [_MatrixMultiplication.apply(x, matrix, matrix_adjoint) for x in xflat] + + y = torch.stack([el.reshape(self._range_shape) for el in yl], -4) y = y.reshape(*y.shape[:-4], *x.shape[:-3], *y.shape[-3:]) return (y,) From 4ff8cf9f13fe23582bdb759b6d7e90660c70b655 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 30 Sep 2024 19:08:35 +0200 Subject: [PATCH 2/6] Move Rotation to data and create utils.typing (#418) --- src/mrpro/{utils => data}/Rotation.py | 34 +++++++++------------ src/mrpro/data/__init__.py | 1 + src/mrpro/operators/SliceProjectionOp.py | 2 +- src/mrpro/utils/__init__.py | 4 +-- src/mrpro/utils/modify_acq_info.py | 6 +++- src/mrpro/utils/typing.py | 18 +++++++++++ tests/operators/test_slice_projection_op.py | 3 +- tests/utils/test_rotation.py | 5 ++- 8 files changed, 44 insertions(+), 29 deletions(-) rename src/mrpro/{utils => data}/Rotation.py (97%) create mode 100644 src/mrpro/utils/typing.py diff --git a/src/mrpro/utils/Rotation.py b/src/mrpro/data/Rotation.py similarity index 97% rename from src/mrpro/utils/Rotation.py rename to src/mrpro/data/Rotation.py index caa48e69..a0cd16f6 100644 --- a/src/mrpro/utils/Rotation.py +++ b/src/mrpro/data/Rotation.py @@ -45,7 +45,7 @@ import re import warnings from collections.abc import Sequence -from typing import TYPE_CHECKING, Literal, Self, overload +from typing import Literal, Self, overload import numpy as np import torch @@ -53,17 +53,7 @@ from scipy.spatial.transform import Rotation as Rotation_scipy from mrpro.data.SpatialDimension import SpatialDimension - -if TYPE_CHECKING: - from types import EllipsisType - from typing import TYPE_CHECKING, SupportsIndex, TypeAlias - - from torch._C import _NestedSequence - - # This matches the torch.Tensor indexer typehint - _IndexerTypeInner: TypeAlias = None | bool | int | slice | EllipsisType | torch.Tensor - _SingleIndexerType: TypeAlias = SupportsIndex | _IndexerTypeInner | _NestedSequence[_IndexerTypeInner] - IndexerType: TypeAlias = tuple[_SingleIndexerType, ...] | _SingleIndexerType +from mrpro.utils.typing import IndexerType, NestedSequence AXIS_ORDER = 'zyx' # This can be modified QUAT_AXIS_ORDER = AXIS_ORDER + 'w' # Do not modify @@ -252,7 +242,11 @@ def _quaternion_to_euler(quaternion: torch.Tensor, seq: str, extrinsic: bool): class Rotation(torch.nn.Module): - """A pytorch implementation of scipy.spatial.transform.Rotation. + """A container for Rotations. + + A pytorch implementation of scipy.spatial.transform.Rotation. + For more information see the scipy documentation: + https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.html Differences compared to scipy.spatial.transform.Rotation: @@ -262,7 +256,7 @@ class Rotation(torch.nn.Module): - arbitrary number of batching dimensions """ - def __init__(self, quaternions: torch.Tensor | _NestedSequence[float], normalize: bool = True, copy: bool = True): + def __init__(self, quaternions: torch.Tensor | NestedSequence[float], normalize: bool = True, copy: bool = True): """Initialize a new Rotation. Instead of calling this method, also consider the different ``from_*`` class methods to construct a Rotation. 
@@ -311,7 +305,7 @@ def single(self) -> bool: return self._single @classmethod - def from_quat(cls, quaternions: torch.Tensor | _NestedSequence[float]) -> Self: + def from_quat(cls, quaternions: torch.Tensor | NestedSequence[float]) -> Self: """Initialize from quaternions. 3D rotations can be represented using unit-norm quaternions [QUAa]_. @@ -338,7 +332,7 @@ def from_quat(cls, quaternions: torch.Tensor | _NestedSequence[float]) -> Self: return cls(quaternions, normalize=True) @classmethod - def from_matrix(cls, matrix: torch.Tensor | _NestedSequence[float]) -> Self: + def from_matrix(cls, matrix: torch.Tensor | NestedSequence[float]) -> Self: """Initialize from rotation matrix. Rotations in 3 dimensions can be represented with 3 x 3 proper @@ -376,7 +370,7 @@ def from_matrix(cls, matrix: torch.Tensor | _NestedSequence[float]) -> Self: return cls(quaternions, normalize=True, copy=False) @classmethod - def from_rotvec(cls, rotvec: torch.Tensor | _NestedSequence[float], degrees: bool = False) -> Self: + def from_rotvec(cls, rotvec: torch.Tensor | NestedSequence[float], degrees: bool = False) -> Self: """Initialize from rotation vector. A rotation vector is a 3 dimensional vector which is co-directional to the @@ -410,7 +404,7 @@ def from_rotvec(cls, rotvec: torch.Tensor | _NestedSequence[float], degrees: boo return cls(quaternions, normalize=False, copy=False) @classmethod - def from_euler(cls, seq: str, angles: torch.Tensor | _NestedSequence[float] | float, degrees: bool = False) -> Self: + def from_euler(cls, seq: str, angles: torch.Tensor | NestedSequence[float] | float, degrees: bool = False) -> Self: """Initialize from Euler angles. Rotations in 3-D can be represented by a sequence of 3 @@ -689,7 +683,7 @@ def concatenate(cls, rotations: Sequence[Rotation]) -> Self: def forward( self, - vectors: _NestedSequence[float] | torch.Tensor | SpatialDimension[torch.Tensor] | SpatialDimension[float], + vectors: NestedSequence[float] | torch.Tensor | SpatialDimension[torch.Tensor] | SpatialDimension[float], inverse: bool = False, ) -> torch.Tensor | SpatialDimension[torch.Tensor]: """Apply this rotation to a set of vectors. 
@@ -1203,7 +1197,7 @@ def __repr__(self): def mean( self, - weights: torch.Tensor | _NestedSequence[float] | None = None, + weights: torch.Tensor | NestedSequence[float] | None = None, dim: None | int | Sequence[int] = None, keepdim: bool = False, ) -> Self: diff --git a/src/mrpro/data/__init__.py b/src/mrpro/data/__init__.py index 7e19e278..06ef1d76 100644 --- a/src/mrpro/data/__init__.py +++ b/src/mrpro/data/__init__.py @@ -14,5 +14,6 @@ from mrpro.data.MoveDataMixin import MoveDataMixin from mrpro.data.QData import QData from mrpro.data.QHeader import QHeader +from mrpro.data.Rotation import Rotation from mrpro.data.SpatialDimension import SpatialDimension from mrpro.data.TrajectoryDescription import TrajectoryDescription diff --git a/src/mrpro/operators/SliceProjectionOp.py b/src/mrpro/operators/SliceProjectionOp.py index 1233e86a..40f632f9 100644 --- a/src/mrpro/operators/SliceProjectionOp.py +++ b/src/mrpro/operators/SliceProjectionOp.py @@ -11,9 +11,9 @@ from numpy._typing import _NestedSequence as NestedSequence from torch import Tensor +from mrpro.data.Rotation import Rotation from mrpro.data.SpatialDimension import SpatialDimension from mrpro.operators.LinearOperator import LinearOperator -from mrpro.utils.Rotation import Rotation from mrpro.utils.slice_profiles import SliceSmoothedRectangular diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index 5c02c58f..08708819 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -1,7 +1,7 @@ +import mrpro.utils.slice_profiles +import mrpro.utils.typing from mrpro.utils.smap import smap from mrpro.utils.remove_repeat import remove_repeat from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop from mrpro.utils.modify_acq_info import modify_acq_info from mrpro.utils.split_idx import split_idx -from mrpro.utils.Rotation import Rotation -import mrpro.utils.slice_profiles diff --git a/src/mrpro/utils/modify_acq_info.py b/src/mrpro/utils/modify_acq_info.py index 5ce5e9c9..d535e53c 100644 --- a/src/mrpro/utils/modify_acq_info.py +++ b/src/mrpro/utils/modify_acq_info.py @@ -1,11 +1,15 @@ """Modify AcqInfo.""" +from __future__ import annotations + import dataclasses from collections.abc import Callable +from typing import TYPE_CHECKING import torch -from mrpro.data.AcqInfo import AcqInfo +if TYPE_CHECKING: + from mrpro.data.AcqInfo import AcqInfo def modify_acq_info(fun_modify: Callable, acq_info: AcqInfo) -> AcqInfo: diff --git a/src/mrpro/utils/typing.py b/src/mrpro/utils/typing.py new file mode 100644 index 00000000..ff829b38 --- /dev/null +++ b/src/mrpro/utils/typing.py @@ -0,0 +1,18 @@ +"""Some type hints that are used in multiple places in the codebase but not part of mrpro's public API.""" + +from typing import TYPE_CHECKING, Any, TypeAlias + +if TYPE_CHECKING: + from types import EllipsisType + from typing import SupportsIndex, TypeAlias + + import torch + from torch._C import _NestedSequence as NestedSequence + + # This matches the torch.Tensor indexer typehint + _IndexerTypeInner: TypeAlias = None | bool | int | slice | EllipsisType | torch.Tensor + _SingleIndexerType: TypeAlias = SupportsIndex | _IndexerTypeInner | NestedSequence[_IndexerTypeInner] + IndexerType: TypeAlias = tuple[_SingleIndexerType, ...] 
| _SingleIndexerType +else: + IndexerType: TypeAlias = Any + NestedSequence: TypeAlias = Any diff --git a/tests/operators/test_slice_projection_op.py b/tests/operators/test_slice_projection_op.py index a2737152..09a296d4 100644 --- a/tests/operators/test_slice_projection_op.py +++ b/tests/operators/test_slice_projection_op.py @@ -5,9 +5,8 @@ import numpy as np import pytest import torch -from mrpro.data import SpatialDimension +from mrpro.data import Rotation, SpatialDimension from mrpro.operators import SliceProjectionOp -from mrpro.utils import Rotation from mrpro.utils.slice_profiles import SliceGaussian, SliceInterpolate, SliceSmoothedRectangular from tests import RandomGenerator diff --git a/tests/utils/test_rotation.py b/tests/utils/test_rotation.py index bb8777f0..5f9831b6 100644 --- a/tests/utils/test_rotation.py +++ b/tests/utils/test_rotation.py @@ -41,9 +41,8 @@ import numpy as np import pytest import torch -from mrpro.data import SpatialDimension -from mrpro.utils import Rotation -from mrpro.utils.Rotation import AXIS_ORDER +from mrpro.data import Rotation, SpatialDimension +from mrpro.data.Rotation import AXIS_ORDER from scipy.stats import special_ortho_group from tests import RandomGenerator From f11f047014cc6a4bf3b99ffa0c0add77b3eb5958 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 30 Sep 2024 23:57:24 +0200 Subject: [PATCH 3/6] Move GridSamplingOp to new autograd.Function syntax (#408) --- src/mrpro/operators/GridSamplingOp.py | 76 +++++++++++++++++---------- 1 file changed, 49 insertions(+), 27 deletions(-) diff --git a/src/mrpro/operators/GridSamplingOp.py b/src/mrpro/operators/GridSamplingOp.py index 3294f86e..e88cab16 100644 --- a/src/mrpro/operators/GridSamplingOp.py +++ b/src/mrpro/operators/GridSamplingOp.py @@ -11,6 +11,19 @@ from mrpro.operators.LinearOperator import LinearOperator +class _AdjointGridSampleCtx(torch.autograd.function.FunctionCtx): + """Context for Adjoint Grid Sample, used for type hinting.""" + + shape: Sequence[int] + interpolation_mode: int + padding_mode: int + align_corners: bool + xshape: Sequence[int] + backward_2d_or_3d: Callable + saved_tensors: Sequence[torch.Tensor] + needs_input_grad: Sequence[bool] + + class AdjointGridSample(torch.autograd.Function): """Autograd Function for Adjoint Grid Sample. @@ -20,14 +33,13 @@ class AdjointGridSample(torch.autograd.Function): @staticmethod def forward( - ctx: torch.autograd.function.FunctionCtx, y: torch.Tensor, grid: torch.Tensor, xshape: Sequence[int], interpolation_mode: Literal['bilinear', 'nearest', 'bicubic'] = 'bilinear', padding_mode: Literal['zeros', 'border', 'reflection'] = 'zeros', align_corners: bool = True, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, tuple[int, int, Callable]]: """Adjoint of the linear operator x->gridsample(x,grid). Parameters @@ -84,18 +96,6 @@ def forward( raise ValueError(f'xshape and y must have same number of channels, got {xshape[1]} and {y.shape[1]}.') if len(xshape) - 2 != dim: raise ValueError(f'len(xshape) and dim must either both bei 2 or 3, got {len(xshape)} and {dim}') - - # These are required in the backward. 
- ctx.xshape = xshape # type: ignore[attr-defined] - ctx.interpolation_mode = mode_enum # type: ignore[attr-defined] - ctx.padding_mode = padding_mode_enum # type: ignore[attr-defined] - ctx.align_corners = align_corners # type: ignore[attr-defined] - ctx.backward_2d_or_3d = backward_2d_or_3d # type: ignore[attr-defined] - if grid.requires_grad: - # only if we need to calculate the gradient for grid we need y - ctx.save_for_backward(grid, y) - else: - ctx.save_for_backward(grid) dummy = torch.empty(1, dtype=y.dtype, device=y.device).broadcast_to(xshape) x = backward_2d_or_3d( y, @@ -106,15 +106,37 @@ def forward( align_corners=align_corners, output_mask=[True, False], )[0] - return x + + return x, (mode_enum, padding_mode_enum, backward_2d_or_3d) + + @staticmethod + def setup_context( + ctx: _AdjointGridSampleCtx, + inputs: tuple[torch.Tensor, torch.Tensor, Sequence[int], str, str, bool], + outputs: tuple[torch.Tensor, tuple[int, int, Callable]], + ) -> None: + """Save information for backward pass.""" + y, grid, xshape, _, _, align_corners = inputs + _, (mode_enum, padding_mode_enum, backward_2d_or_3d) = outputs + ctx.xshape = xshape + ctx.interpolation_mode = mode_enum + ctx.padding_mode = padding_mode_enum + ctx.align_corners = align_corners + ctx.backward_2d_or_3d = backward_2d_or_3d + + if ctx.needs_input_grad[1]: + # only if we need to calculate the gradient for grid we need y + ctx.save_for_backward(grid, y) + else: + ctx.save_for_backward(grid) @staticmethod def backward( - ctx: torch.autograd.function.FunctionCtx, *grad_output: torch.Tensor + ctx: _AdjointGridSampleCtx, *grad_output: torch.Tensor ) -> tuple[torch.Tensor | None, torch.Tensor | None, None, None, None, None]: """Backward of the Adjoint Gridsample Operator.""" - need_y_grad, need_grid_grad, *_ = ctx.needs_input_grad # type: ignore[attr-defined] - grid = ctx.saved_tensors[0] # type: ignore[attr-defined] + need_y_grad, need_grid_grad, *_ = ctx.needs_input_grad + grid = ctx.saved_tensors[0] if need_y_grad: # torch.grid_sampler has the same signature as the backward @@ -122,22 +144,22 @@ def backward( grad_y = torch.grid_sampler( grad_output[0], grid, - ctx.interpolation_mode, # type: ignore[attr-defined] - ctx.padding_mode, # type: ignore[attr-defined] - ctx.align_corners, # type: ignore[attr-defined] + ctx.interpolation_mode, + ctx.padding_mode, + ctx.align_corners, ) else: grad_y = None if need_grid_grad: - y = ctx.saved_tensors[1] # type: ignore[attr-defined] - grad_grid = ctx.backward_2d_or_3d( # type: ignore[attr-defined] + y = ctx.saved_tensors[1] + grad_grid = ctx.backward_2d_or_3d( y, grad_output[0], grid, - interpolation_mode=ctx.interpolation_mode, # type: ignore[attr-defined] - padding_mode=ctx.padding_mode, # type: ignore[attr-defined] - align_corners=ctx.align_corners, # type: ignore[attr-defined] + interpolation_mode=ctx.interpolation_mode, + padding_mode=ctx.padding_mode, + align_corners=ctx.align_corners, output_mask=[False, True], )[1] else: @@ -299,7 +321,7 @@ def _adjoint_implementation( self.interpolation_mode, self.padding_mode, self.align_corners, - ) + )[0] return sampled def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor]: From 763ce964703281f8d621d0ae94257e022acfb460 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Tue, 1 Oct 2024 20:23:07 +0200 Subject: [PATCH 4/6] Add link to docu as PR comment (#423) Co-authored-by: Felix F Zimmermann --- .github/workflows/docs.yml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/.github/workflows/docs.yml 
b/.github/workflows/docs.yml index b06e925b..3ee3ce7c 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -137,6 +137,9 @@ jobs: container: image: ghcr.io/ptb-mr/mrpro_py311:latest options: --user runner + permissions: + pull-requests: write + contents: write steps: - name: Checkout uses: actions/checkout@v4 @@ -192,11 +195,30 @@ jobs: if: github.event_name != 'pull_request' - name: Save Documentation + id: save_docu uses: actions/upload-artifact@v4 with: name: Documentation path: docs/build/html/ + - run: echo '${{ steps.save_docu.outputs.artifact-url }}' + + - run: echo '${{ github.event.number }}' + + - name: Update PR with link to summary + uses: edumserrano/find-create-or-update-comment@v3 + with: + issue-number: ${{ github.event.pull_request.number }} + body-includes: '' + comment-author: 'github-actions[bot]' + body: | + + ### :books: Documentation + :file_folder: [Download as zip](${{ steps.save_docu.outputs.artifact-url }}) + :mag: [View online](https://zimf.de/zipserve/${{ steps.save_docu.outputs.artifact-url }}/) + edit-mode: replace + + concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} From 4dc01c882958edef3f14b5958390b946a9f59f1e Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Tue, 8 Oct 2024 08:00:14 +0200 Subject: [PATCH 5/6] Fix434 - no link to docu for merge with main (#437) --- .github/workflows/docs.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 3ee3ce7c..b043eeb3 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -201,11 +201,14 @@ jobs: name: Documentation path: docs/build/html/ - - run: echo '${{ steps.save_docu.outputs.artifact-url }}' + - run: echo 'Artifact url ${{ steps.save_docu.outputs.artifact-url }}' - - run: echo '${{ github.event.number }}' + - run: echo 'Event number ${{ github.event.number }}' + + - run: echo 'Event name ${{github.event_name}}' - name: Update PR with link to summary + if: github.event_name == 'pull_request' uses: edumserrano/find-create-or-update-comment@v3 with: issue-number: ${{ github.event.pull_request.number }} From 3616027b70a560aa30fd4115a84d823a271a2769 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Tue, 8 Oct 2024 13:41:38 +0200 Subject: [PATCH 6/6] Gitignore common data file types (#424) --- .gitignore | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/.gitignore b/.gitignore index d2c098d7..c1694cad 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,27 @@ __pycache__/ # C extensions *.so +# Usual suspects regarding large file size +*.h5 +*.dat +*.mp4 +*.mrd +*.npy +*.npz +*.pt +*.ptc +*.pwf +*.pth +*.pdf +*.png +*.gif +*.nii +*.dcm +*.IMA +*.tar +*.zip +*.gz + # Distribution / packaging .Python build/
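
For reference, patches 1 and 3 move mrpro's custom autograd functions from the legacy style (ctx passed into forward) to the current torch.autograd.Function layout, in which forward is a pure function of its inputs and anything needed by backward is stored in a separate setup_context step. The following is a minimal sketch of that layout using an illustrative _Scale function; it is not mrpro code, and the names are made up for the example:

    import torch
    from torch import Tensor

    class _Scale(torch.autograd.Function):
        """Illustrative new-style Function: forward no longer receives ctx."""

        @staticmethod
        def forward(x: Tensor, factor: Tensor) -> Tensor:
            # Pure computation only; nothing is stored on a context here.
            return x * factor

        @staticmethod
        def setup_context(ctx, inputs: tuple[Tensor, Tensor], output: Tensor) -> None:
            # State needed by backward is saved in this separate step.
            _, factor = inputs
            ctx.save_for_backward(factor)

        @staticmethod
        def backward(ctx, grad_output: Tensor) -> tuple[Tensor, None]:
            # Return one gradient per forward input; None for inputs without gradients.
            (factor,) = ctx.saved_tensors
            return grad_output * factor, None

    x = torch.ones(3, requires_grad=True)
    y = _Scale.apply(x, torch.tensor(2.0))
    y.sum().backward()  # x.grad is now tensor([2., 2., 2.])

The practical difference is that forward no longer mutates ctx, which is the split newer PyTorch releases expect for custom Functions (and what torch.func-style transforms can work with); _MatrixMultiplication and AdjointGridSample in the patches above follow the same three-method layout, with the sparse-matrix and grid-sample specifics added on top.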