Incorrect gradient in toy adaptive ODE #499

Open
lockwo opened this issue Aug 24, 2024 · 4 comments
Labels
documentation (Improvements or additions to documentation)

Comments

@lockwo
Contributor

lockwo commented Aug 24, 2024

We are encountering incorrect gradients in a specific regime. Specifically, we have:

  • A custom solver, in which the error estimate depends on a second call to the drift function or on the step times
  • Adaptive stepping

Below is a simplified example. We simply take Euler and make a trivial change for the sake of example (our actual solver is more complicated, but we have identified this as the root of the issue): crucially, it has a y error that depends on a recalculation of the drift function (note that with or without the stop gradients makes no difference). There doesn't seem to be anything wrong with the PIDController, since we also implemented a simple controller ourselves and the same error shows up. If constant stepping is used, the gradients are accurate. Note that our finite difference is stable: we have tried epsilon from 1e-10 to 1e-3 and it gives consistent results. The primal values are correct, but there is a difference in the gradient.

import jax

jax.config.update("jax_enable_x64", True)

from jax import numpy as jnp
import diffrax
from typing import ClassVar
import equinox as eqx
from equinox.internal import ω


class Test(diffrax.AbstractItoSolver):
    term_structure: ClassVar = diffrax.AbstractTerm
    interpolation_cls: ClassVar = diffrax.LocalLinearInterpolation

    def order(self, terms):
        return 1

    def strong_order(self, terms):
        return 0.5

    def init(
        self,
        terms,
        t0,
        t1,
        y0,
        args,
    ):
        return None

    def func(
        self,
        terms,
        t0,
        y0,
        args,
    ):
        return terms.vf(t0, y0, args)

    def step(
        self,
        terms,
        t0,
        t1,
        y0,
        args,
        solver_state,
        made_jump,
    ):
        del made_jump
        control = terms.contr(t0, t1)
        # A plain Euler step.
        y1 = (y0**ω + terms.vf_prod(t0, y0, args, control) ** ω).ω

        # The error estimate re-evaluates the drift; the stop_gradients make no
        # difference to the behaviour reported here.
        drift = terms
        b = jax.lax.stop_gradient(drift.vf(t0, y0, args))
        y_error = jax.lax.stop_gradient(jnp.linalg.norm(b) * (t1 - t0))

        dense_info = dict(y0=y0, y1=y1)
        return y1, y_error, dense_info, solver_state, diffrax.RESULTS.successful

t0, t1 = 0.0, 3.0
y0 = jnp.array([1.0, 1.0])
tol = 1e-1
solver = Test()
cont = diffrax.PIDController(tol, tol, error_order=1.0)


def drift(t, X, args):
    y1, y2 = X
    dy1 = -273 / 512 * y1
    dy2 = -1 // 160 * y1 - (-785 // 512 + jnp.sqrt(2) / 8) * y2
    return jnp.array([dy1, dy2])

def solve(key, y0):
    terms = diffrax.ODETerm(drift)
    saveat = diffrax.SaveAt(t1=True)
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0=0.0001,
        y0=y0,
        saveat=saveat,
        max_steps=1000,
        stepsize_controller=cont,
        adjoint=diffrax.RecursiveCheckpointAdjoint(),
    )
    return sol

def loss(y):
    k = jax.random.key(0)
    s = solve(k, y)
    return jnp.sqrt(jnp.mean(s.ys ** 2)), s.stats

x0 = jnp.array([1.0, 1.0])
print(eqx.filter_value_and_grad(loss, has_aux=True)(x0))

def finite_diff(y):
    eps = 1e-9
    val1 = loss(jnp.array([y[0] + eps / 2, y[1]]))[0]
    val2 = loss(jnp.array([y[0], y[1] + eps / 2]))[0]
    val3 = loss(jnp.array([y[0] - eps / 2, y[1]]))[0]
    val4 = loss(jnp.array([y[0], y[1] - eps / 2]))[0]
    print(val1, val2, val3, val4)
    return jnp.array([val1 - val3, val2 - val4]) / eps

print(finite_diff(x0))

prints

((Array(81.89217529, dtype=float64),
  {'max_steps': 1000,
   'num_accepted_steps': Array(682, dtype=int64, weak_type=True),
   'num_rejected_steps': Array(44, dtype=int64, weak_type=True),
   'num_steps': Array(726, dtype=int64, weak_type=True)}),
 Array([-60.26947042, 142.16164571], dtype=float64))

81.8921752553513 81.89217537306844 81.89217532865639 81.89217521093927
Array([-73.30508822, 162.12916876], dtype=float64)

We see an accurate primal but inaccurate gradients (by enough that this cannot just be numerical noise; we have tried other problems and see even larger differences). The error order is wrong too, but that shouldn't matter: it should only make the solver converge incorrectly, not change its differentiability. Are we violating some requirement by evaluating the drift again? Everything should be differentiable (and we tried anywhere from zero to many, many stop gradients around all the error-related terms and couldn't get anything to change).
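
As a minimal sketch of the constant-stepping claim above (reusing drift, solver, t0, t1 and x0 from the snippet; the fixed dt0 of 0.01 and the epsilon are illustrative choices, not the exact values from our runs):

def loss_constant(y):
    # Same solve as above, but with a fixed step size instead of the PID controller.
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(drift),
        solver,
        t0,
        t1,
        dt0=0.01,
        y0=y,
        saveat=diffrax.SaveAt(t1=True),
        max_steps=1000,
        stepsize_controller=diffrax.ConstantStepSize(),
        adjoint=diffrax.RecursiveCheckpointAdjoint(),
    )
    return jnp.sqrt(jnp.mean(sol.ys**2))

eps = 1e-6
e0 = jnp.array([eps / 2, 0.0])
e1 = jnp.array([0.0, eps / 2])
fd = jnp.array([
    (loss_constant(x0 + e0) - loss_constant(x0 - e0)) / eps,
    (loss_constant(x0 + e1) - loss_constant(x0 - e1)) / eps,
])
# As noted above, with constant steps AD and finite differences agree
# (up to floating-point noise).
print(jax.grad(loss_constant)(x0))
print(fd)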

@lockwo
Contributor Author

lockwo commented Aug 24, 2024

Follow-up: I just tried a trivial implementation of Heun and it is also not working. The mean/square in the loss has no impact either; I tested without it.

Follow-up to the follow-up: the real (diffrax) Heun doesn't work either, that is to say, the gradients of Heun and finite differences don't match up. Now I am confused. The finite difference is extremely stable: it matches the primal exactly and gives consistent gradients for epsilon from 1e-2 down to 1e-15 and beyond.

If I decrease the tolerance, I see both matching up. Only at large tolerances do they disagree.
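
A rough sketch of that tolerance sweep, using diffrax's built-in Heun (the tolerance grid, max_steps and epsilon below are illustrative choices, not the exact values I used):

def make_loss(tol):
    controller = diffrax.PIDController(rtol=tol, atol=tol)

    def loss_heun(y):
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(drift),
            diffrax.Heun(),
            t0,
            t1,
            dt0=0.0001,
            y0=y,
            saveat=diffrax.SaveAt(t1=True),
            max_steps=100_000,
            stepsize_controller=controller,
        )
        return jnp.sqrt(jnp.mean(sol.ys**2))

    return loss_heun

eps = 1e-7
e0 = jnp.array([eps / 2, 0.0])
e1 = jnp.array([0.0, eps / 2])
for tol in (1e-1, 1e-3, 1e-6):
    loss_heun = make_loss(tol)
    ad_grad = jax.grad(loss_heun)(x0)
    fd_grad = jnp.array([
        (loss_heun(x0 + e0) - loss_heun(x0 - e0)) / eps,
        (loss_heun(x0 + e1) - loss_heun(x0 - e1)) / eps,
    ])
    # The two disagree at tol = 1e-1 and come into agreement as tol shrinks.
    print(tol, ad_grad, fd_grad)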

@lockwo
Contributor Author

lockwo commented Aug 24, 2024

Looking deeper, I thought it might be a situation like §4.1.2.4 of https://arxiv.org/abs/2406.09699, where AD is numerically wrong (see also https://github.com/ODINN-SciML/DiffEqSensitivity-Review/blob/main/code/SensitivityForwardAD/testgradient_python.py), since jacrev and jacfwd both work. However, they argue this is true at any tolerance, whereas I only see it at large tolerances. Maybe the solution is just not to use large tolerances with first-order methods? But my confusion is that this should still be differentiable. Also, the paper said this works out in SciMLSensitivity in Julia, so we implemented it in Julia and saw that it is wrong there too (which is extra surprising, because the finite-difference trajectories are basically identical to the reverse-diff trajectories).

Julia code + Results
using OrdinaryDiffEq
using FiniteDiff, ForwardDiff, Statistics, Zygote, ReverseDiff, SciMLSensitivity

function f(X, args, t)
    y1, y2 = X
    dy1 = -273 / 512 * y1
    dy2 = -1 / 160 * y1 - (-785 / 512 + sqrt(2) / 8) * y2
    return [dy1, dy2]
end

u0 = [1.0, 1.0]
args = ones(1)
odeprob = ODEProblem(f, u0, (0.0, 3.0), args)

function loss(u0)
    _prob = remake(odeprob, u0=u0)
    _sol = (solve(_prob, Heun(),
        dt=0.1,
        abstol=0.1,
        reltol=0.1,
        save_everystep=true,
        save_start = false,
        #adaptive=false
        controller=IController(), #CustomController(),
        sensealg=ReverseDiffAdjoint()
    ))
    @show (_sol.t)
    _sol = _sol[end]

    return sum(abs2, _sol)
end

function finite_diff(u0)
    eps = 1e-5
    v1 = loss([u0[1] + eps / 2, u0[2]])
    v2 = loss([u0[1] - eps / 2, u0[2]])
    v3 = loss([u0[1], u0[2] + eps / 2])
    v4 = loss([u0[1], u0[2] - eps / 2])

    [v1 - v2, v3 - v4] / eps
end

begin
    println("FiniteDiff")
    @show finite_diff(u0)
    dp1 = FiniteDiff.finite_difference_gradient(loss, u0)
    println("Forward")
    dp2 = ForwardDiff.gradient(loss, u0)
    println("Reverse")
    dp3 = Zygote.gradient(loss, u0)[1]
    @show dp1 dp2 dp3
end
FiniteDiff
_sol.t = [0.1, 0.613929509825019, 1.199780114366179, 1.7608249943840808, 2.28993033847855, 2.797764017021789, 3.0]
_sol.t = [0.1, 0.6139295972860885, 1.1997803015279263, 1.7608252770250412, 2.2899306781168063, 2.7977643468051436, 3.0]
_sol.t = [0.1, 0.6139290287896685, 1.1997791292067386, 1.7608237043621409, 2.289928860088721, 2.797762422397756, 3.0]
_sol.t = [0.1, 0.6139301220531311, 1.1997813360415173, 1.7608266217853026, 2.289932183276105, 2.7977659413541907, 3.0]
finite_diff(u0) = [7.376464645858504, 4823.877383614672]
Forward
_sol.t = [0.1, 0.728028559619219, 1.4979699401846973, 2.275540178765593, 3.0]
Reverse
_sol.t = [0.1, 0.6139295972860885, 1.1997802572988525, 1.760825190439614, 2.2899305515863184, 2.7977642464962393, 3.0]
dp1 = [3.4720645693875793, 4831.212995065322]
dp2 = [-11.629850313906191, 3561.5179283293633]
dp3 = [-15.024234205512382, 4588.183943938828]
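
The shifting _sol.t values above show that the accepted step times themselves depend on the initial condition, so the finite difference is comparing two slightly different discretizations. The same effect can be seen in diffrax; this is only an illustrative sketch (reusing drift, t0, t1 and x0 from the first snippet, with an arbitrary tolerance and perturbation):

def accepted_times(y):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(drift),
        diffrax.Heun(),
        t0,
        t1,
        dt0=0.0001,
        y0=y,
        saveat=diffrax.SaveAt(steps=True),  # save every accepted step
        max_steps=1000,
        stepsize_controller=diffrax.PIDController(rtol=1e-1, atol=1e-1),
    )
    ts = sol.ts
    return ts[jnp.isfinite(ts)]  # drop the padding entries

print(accepted_times(x0))
print(accepted_times(x0 + jnp.array([1e-3, 0.0])))  # the accepted step times shift with y0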

@lockwo
Contributor Author

lockwo commented Aug 25, 2024

There was some good discussion in SciML/SciMLSensitivity.jl#1094. Given that this clearly isn't a fault of diffrax (or of the Julia SciML ecosystem), the original points in my issue aren't as relevant. But maybe this could go in the docs somewhere? Or just a reference to numerical vs. algorithmic accuracy considerations? As someone not super knowledgeable about discrete vs. continuous adjoints, this was a tough nut to crack, so I'd like to spare some future person the amount of work we put into this if possible lol.

@patrick-kidger
Owner

Ah, you're bumping into the esoteric end of the autodiff literature!

An FAQ entry sounds reasonable.

@patrick-kidger added the documentation label Aug 25, 2024