Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fzimmermann89 committed Nov 14, 2024
1 parent e77fdc7 commit 6688633
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/mrpro/operators/Jacobian.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Jacobian."""

from collections.abc import Callable
from typing import Unpack

import torch

Expand All @@ -25,7 +26,7 @@ def __init__(self, operator: Operator[torch.Tensor, tuple[torch.Tensor]], *x0: t
point at which to linearize the operator
"""
super().__init__()
self._vjp: Callable[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]] | None = None
self._vjp: Callable[[Unpack[tuple[torch.Tensor, ...]]], tuple[torch.Tensor, ...]] | None = None
self._x0: tuple[torch.Tensor, ...] = x0
self._operator = operator
self._f_x0: tuple[torch.Tensor, ...] | None = None
Expand Down
25 changes: 25 additions & 0 deletions tests/operators/test_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


def test_jacobian_adjointness():
"""Test adjointness of Jacobian operator."""
rng = RandomGenerator(123)
x = rng.float32_tensor(3)
y = rng.float32_tensor(())
Expand All @@ -17,10 +18,34 @@ def test_jacobian_adjointness():


def test_jacobian_taylor():
"""Test Taylor expansion"""
rng = RandomGenerator(123)
x0 = rng.float32_tensor(3)
x = x0 + 1e-2 * rng.float32_tensor(3)
op = L2NormSquared()
jacobian = Jacobian(op, x0)
fx = jacobian.taylor(x)
torch.testing.assert_close(fx, op(x), rtol=1e-3, atol=1e-3)


def test_jacobian_gaussnewton():
"""Test Gauss Newton approximation of the Hessian"""
rng = RandomGenerator(123)
x0 = rng.float32_tensor(3)
x = x0 + 1e-2 * rng.float32_tensor(3)
op = L2NormSquared()
jacobian = Jacobian(op, x0)
(actual,) = jacobian.gauss_newton(x)
expected = torch.vdot(x, x0) * 4 * x0 # analytical solution for L2NormSquared
torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)


def test_jacobian_valueatx0():
"""Test value at x0"""
rng = RandomGenerator(123)
x0 = rng.float32_tensor(3)
op = L2NormSquared()
jacobian = Jacobian(op, x0)
(actual,) = jacobian.value_at_x0
(expected,) = op(x0)
torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)

0 comments on commit 6688633

Please sign in to comment.