Skip to content

Commit

Permalink
Merge pull request #12036 from hawkinsp:gpu
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 469211693
  • Loading branch information
jax authors committed Aug 22, 2022
2 parents da4e79a + c247b9b commit 384776f
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2855,7 +2855,8 @@ def test_for_grad(self, jit_for, f, ref, body_shapes, n):
self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol,
atol=tol)
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=3)
jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=3,
rtol=5e-3)

if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 384776f

Please sign in to comment.