Skip to content

Commit

Permalink
Add regression test
Browse files Browse the repository at this point in the history
  • Loading branch information
samuela committed Apr 24, 2020
1 parent c7d597f commit 04bae57
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions jax/experimental/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,17 @@ def f(y0, ts, *args):
check_grads(f, (y0, ts, *args), modes=["rev"], order=2,
atol=1e-1, rtol=1e-1)

def weird_time_pendulum_check_grads():
"""Test that gradients are correct when the dynamics depend on t."""
def f(y0, ts):
return odeint(lambda y, t: np.array([y[1] * -t, -1 * y[1] - 9.8 * np.sin(y[0])]), y0, ts)

y0 = [np.pi - 0.1, 0.0]
ts = np.linspace(0., 1., 11)

check_grads(f, (y0, ts), modes=["rev"], order=2)

if __name__ == '__main__':
pend_benchmark_odeint()
pend_check_grads()
weird_time_pendulum_check_grads()

0 comments on commit 04bae57

Please sign in to comment.