diff --git a/tests/firedrake/adjoint/test_optimisation.py b/tests/firedrake/adjoint/test_optimisation.py index 2189a7882b..d024d93d16 100644 --- a/tests/firedrake/adjoint/test_optimisation.py +++ b/tests/firedrake/adjoint/test_optimisation.py @@ -121,13 +121,8 @@ def _simple_helmholz_model(V, source): return u -@pytest.mark.parametrize( - "riesz_representation", - [None, - "l2", - pytest.param("H1", marks=pytest.mark.xfail(reason="H1 is the wrong norm for this problem"))]) @pytest.mark.skipcomplex -def test_simple_inversion(riesz_representation): +def test_simple_inversion(): """Test inversion of source term in helmholze eqn.""" mesh = UnitIntervalMesh(10) V = FunctionSpace(mesh, "CG", 1) @@ -141,7 +136,7 @@ def test_simple_inversion(riesz_representation): # now rerun annotated model with zero source source = Function(V) - c = Control(source, riesz_map=riesz_representation) + c = Control(source, riesz_map="l2") u = _simple_helmholz_model(V, source) J = assemble(1e6 * (u - u_ref)**2*dx) @@ -151,6 +146,44 @@ def test_simple_inversion(riesz_representation): assert_allclose(x.dat.data, source_ref.dat.data, rtol=1e-2) +@pytest.mark.skipcomplex +def test_simple_inversion_scipy_riesz_map_ignored(): + mesh = UnitIntervalMesh(10) + V = FunctionSpace(mesh, "CG", 1) + source_ref = Function(V) + x = SpatialCoordinate(mesh) + source_ref.interpolate(cos(pi*x**2)) + + # compute reference solution + with stop_annotating(): + u_ref = _simple_helmholz_model(V, source_ref) + + class Counter: + def __init__(self): + self._n = 0 + + def __call__(self, xk): + self._n += 1 + + @property + def n(self): + return self._n + + iterations = {} + for riesz_map in ("l2", "L2", "H1"): + source = Function(V) + u = _simple_helmholz_model(V, source) + J = assemble(1e6 * (u - u_ref)**2*dx) + c = Control(source, riesz_map=riesz_map) + rf = ReducedFunctional(J, c) + cb = Counter() + x = minimize(rf, callback=cb) + assert_allclose(x.dat.data, source_ref.dat.data, rtol=1e-2) + iterations[riesz_map] = cb.n + # Should always take the same number of iterations + assert len(set(iterations.values())) == 1 + + @pytest.mark.parametrize("minimize", [minimize_tao_lmvm, minimize_tao_nls]) @pytest.mark.parametrize("riesz_representation", [None, "l2", "L2", "H1"])