Move SliceProjectionOp to new autograd.Function syntax #409

Merged: 1 commit, Sep 30, 2024
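
For context (not part of the diff below): since PyTorch 2.0, `torch.autograd.Function` supports a `forward` without a `ctx` argument plus a separate `setup_context` staticmethod, which is the style this PR migrates `_MatrixMultiplication` to. The following is a minimal, hypothetical sketch of that pattern; the class `_Scale` is illustrative only and does not appear in the PR.

```python
import torch
from torch import Tensor


class _Scale(torch.autograd.Function):
    """Minimal new-style autograd.Function computing y = alpha * x."""

    @staticmethod
    def forward(x: Tensor, alpha: float) -> Tensor:
        # In the new syntax, forward no longer receives ctx.
        return alpha * x

    @staticmethod
    def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs: tuple, output: Tensor) -> None:
        # State needed by backward is stored here instead of inside forward.
        _, alpha = inputs
        ctx.alpha = alpha

    @staticmethod
    def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: Tensor) -> tuple[Tensor, None]:
        # Gradient w.r.t. x; alpha is treated as a non-differentiable constant.
        return ctx.alpha * grad_output, None


x = torch.randn(3, requires_grad=True)
y = _Scale.apply(x, 2.0)
y.sum().backward()  # x.grad is 2.0 everywhere
```
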
33 changes: 24 additions & 9 deletions src/mrpro/operators/SliceProjectionOp.py
@@ -17,6 +17,13 @@
from mrpro.utils.slice_profiles import SliceSmoothedRectangular


class _MatrixMultiplicationCtx(torch.autograd.function.FunctionCtx):
"""Autograd context for matrix multiplication, used for type hinting."""

x_is_complex: bool
saved_tensors: tuple[Tensor]


class _MatrixMultiplication(torch.autograd.Function):
"""Helper for matrix multiplication.

@@ -27,9 +34,7 @@ class _MatrixMultiplication(torch.autograd.Function):
"""

@staticmethod
def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor, matrix: Tensor, matrix_adjoint: Tensor) -> Tensor:
ctx.save_for_backward(matrix_adjoint)
ctx.x_is_complex = x.is_complex() # type: ignore[attr-defined]
def forward(x: Tensor, matrix: Tensor, matrix_adjoint: Tensor) -> Tensor: # noqa: ARG004
if x.is_complex() == matrix.is_complex():
return matrix @ x
# required for sparse matrices to support mixed complex/real multiplication
@@ -39,9 +44,19 @@ def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor, matrix: Tensor,
return torch.complex(matrix.real @ x, matrix.imag @ x)

@staticmethod
def backward(ctx: torch.autograd.function.FunctionCtx, *grad_output: Tensor) -> tuple[Tensor, None, None]:
(matrix_adjoint,) = ctx.saved_tensors # type: ignore[attr-defined]
if ctx.x_is_complex: # type: ignore[attr-defined]
def setup_context(
ctx: _MatrixMultiplicationCtx,
inputs: tuple[Tensor, Tensor, Tensor],
outputs: tuple[Tensor], # noqa: ARG004
) -> None:
x, _, matrix_adjoint = inputs
ctx.x_is_complex = x.is_complex()
ctx.save_for_backward(matrix_adjoint)

@staticmethod
def backward(ctx: _MatrixMultiplicationCtx, *grad_output: Tensor) -> tuple[Tensor, None, None]:
(matrix_adjoint,) = ctx.saved_tensors
if ctx.x_is_complex:
if matrix_adjoint.is_complex() == grad_output[0].is_complex():
grad_x = matrix_adjoint @ grad_output[0]
elif matrix_adjoint.is_complex():
@@ -219,9 +234,9 @@ def forward(self, x: Tensor) -> tuple[Tensor]:

# For the (unusual case) of batched volumes, we will apply for each element in series
xflat = torch.atleast_2d(einops.rearrange(x, '... x y z -> (...) (x y z)'))
y = torch.stack(
[_MatrixMultiplication.apply(x, matrix, matrix_adjoint).reshape(self._range_shape) for x in xflat], -4
)
yl = [_MatrixMultiplication.apply(x, matrix, matrix_adjoint) for x in xflat]

y = torch.stack([el.reshape(self._range_shape) for el in yl], -4)
y = y.reshape(*y.shape[:-4], *x.shape[:-3], *y.shape[-3:])
return (y,)

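
One way to sanity-check the reworked backward is `torch.autograd.gradcheck`. The snippet below is a sketch under stated assumptions: the private import path is inferred from the file shown above, and dense real double-precision inputs are used, neither of which is prescribed by the PR.

```python
import torch
from mrpro.operators.SliceProjectionOp import _MatrixMultiplication  # private helper, path assumed

matrix = torch.randn(5, 3, dtype=torch.float64)
x = torch.randn(3, dtype=torch.float64, requires_grad=True)

# For a real matrix, the adjoint passed as the third argument is simply the transpose.
torch.autograd.gradcheck(lambda x_: _MatrixMultiplication.apply(x_, matrix, matrix.T), (x,))
```
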