diff --git a/src/mrpro/operators/IdentityOp.py b/src/mrpro/operators/IdentityOp.py new file mode 100644 index 00000000..79abbc06 --- /dev/null +++ b/src/mrpro/operators/IdentityOp.py @@ -0,0 +1,44 @@ +"""Identity Operator.""" + +import torch + +from mrpro.operators.LinearOperator import LinearOperator + + +class IdentityOp(LinearOperator): + r"""The Identity Operator. + + A Linear Operator that returns a single input unchanged. + """ + + def __init__(self) -> None: + """Initialize Identity Operator.""" + super().__init__() + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]: + """Identity of input. + + Parameters + ---------- + x + input tensor + + Returns + ------- + the input tensor + """ + return (x,) + + def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor]: + """Adjoint Identity. + + Parameters + ---------- + x + input tensor + + Returns + ------- + the input tensor + """ + return (x,) diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index 721dabb7..71993bc6 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -10,6 +10,7 @@ from mrpro.operators.FiniteDifferenceOp import FiniteDifferenceOp from mrpro.operators.FourierOp import FourierOp from mrpro.operators.GridSamplingOp import GridSamplingOp +from mrpro.operators.IdentityOp import IdentityOp from mrpro.operators.MagnitudeOp import MagnitudeOp from mrpro.operators.PhaseOp import PhaseOp from mrpro.operators.SensitivityOp import SensitivityOp diff --git a/tests/operators/test_identity_op.py b/tests/operators/test_identity_op.py new file mode 100644 index 00000000..85c182a9 --- /dev/null +++ b/tests/operators/test_identity_op.py @@ -0,0 +1,36 @@ +"""Tests for Identity Linear Operator.""" + +import torch +from mrpro.operators import IdentityOp + +from tests import RandomGenerator + + +def test_identity_op(): + """Test forward identity.""" + generator = RandomGenerator(seed=0) + tensor = generator.complex64_tensor(2, 3, 4) + operator = IdentityOp() + torch.testing.assert_close(tensor, *operator(tensor)) + assert tensor is operator(tensor)[0] + + +def test_identity_op_adjoint(): + """Test adjoint identity.""" + generator = RandomGenerator(seed=0) + tensor = generator.complex64_tensor(2, 3, 4) + operator = IdentityOp().H + torch.testing.assert_close(tensor, *operator(tensor)) + assert tensor is operator(tensor)[0] + + +def test_identity_op_operatorsyntax(): + """Test Identity@(Identity*alpha) + (beta*Identity.H).H""" + generator = RandomGenerator(seed=0) + tensor = generator.complex64_tensor(2, 3, 4) + alpha = generator.complex64_tensor(2, 3, 4) + beta = generator.complex64_tensor(2, 3, 4) + composition = IdentityOp() @ (IdentityOp() * alpha) + (beta * IdentityOp().H).H + expected = tensor * alpha + tensor * beta.conj() + (actual,) = composition(tensor) + torch.testing.assert_close(actual, expected)