diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 8b3b96101a27..0de5be05477a 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2855,7 +2855,8 @@ def test_for_grad(self, jit_for, f, ref, body_shapes, n): self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol, atol=tol) self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol) - jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=3) + jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=3, + rtol=5e-3) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())