Skip to content

Commit

Permalink
Add Identity LinearOperator (#390)
Browse files Browse the repository at this point in the history
Add a do-nothing linear operator. 
This might be further extended to allow for multiple inputs (i.e. an endomorph overload).
  • Loading branch information
fzimmermann89 authored Sep 16, 2024
1 parent f64add5 commit 567089a
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 0 deletions.
44 changes: 44 additions & 0 deletions src/mrpro/operators/IdentityOp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Identity Operator."""

import torch

from mrpro.operators.LinearOperator import LinearOperator


class IdentityOp(LinearOperator):
r"""The Identity Operator.
A Linear Operator that returns a single input unchanged.
"""

def __init__(self) -> None:
"""Initialize Identity Operator."""
super().__init__()

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]:
"""Identity of input.
Parameters
----------
x
input tensor
Returns
-------
the input tensor
"""
return (x,)

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor]:
"""Adjoint Identity.
Parameters
----------
x
input tensor
Returns
-------
the input tensor
"""
return (x,)
1 change: 1 addition & 0 deletions src/mrpro/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mrpro.operators.FiniteDifferenceOp import FiniteDifferenceOp
from mrpro.operators.FourierOp import FourierOp
from mrpro.operators.GridSamplingOp import GridSamplingOp
from mrpro.operators.IdentityOp import IdentityOp
from mrpro.operators.MagnitudeOp import MagnitudeOp
from mrpro.operators.PhaseOp import PhaseOp
from mrpro.operators.SensitivityOp import SensitivityOp
Expand Down
36 changes: 36 additions & 0 deletions tests/operators/test_identity_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Tests for Identity Linear Operator."""

import torch
from mrpro.operators import IdentityOp

from tests import RandomGenerator


def test_identity_op():
"""Test forward identity."""
generator = RandomGenerator(seed=0)
tensor = generator.complex64_tensor(2, 3, 4)
operator = IdentityOp()
torch.testing.assert_close(tensor, *operator(tensor))
assert tensor is operator(tensor)[0]


def test_identity_op_adjoint():
"""Test adjoint identity."""
generator = RandomGenerator(seed=0)
tensor = generator.complex64_tensor(2, 3, 4)
operator = IdentityOp().H
torch.testing.assert_close(tensor, *operator(tensor))
assert tensor is operator(tensor)[0]


def test_identity_op_operatorsyntax():
"""Test Identity@(Identity*alpha) + (beta*Identity.H).H"""
generator = RandomGenerator(seed=0)
tensor = generator.complex64_tensor(2, 3, 4)
alpha = generator.complex64_tensor(2, 3, 4)
beta = generator.complex64_tensor(2, 3, 4)
composition = IdentityOp() @ (IdentityOp() * alpha) + (beta * IdentityOp().H).H
expected = tensor * alpha + tensor * beta.conj()
(actual,) = composition(tensor)
torch.testing.assert_close(actual, expected)

0 comments on commit 567089a

Please sign in to comment.