diff --git a/tests/rigid_body_test.py b/tests/rigid_body_test.py index f8a1665..edf90c2 100644 --- a/tests/rigid_body_test.py +++ b/tests/rigid_body_test.py @@ -30,7 +30,6 @@ from jax.config import config as jax_config -jax_config.update('jax_disable_jit', True) import jax.numpy as jnp