diff --git a/tests/test_qutip/test_entropy.py b/tests/test_qutip/test_entropy.py index d069d7e..3ecb9a5 100644 --- a/tests/test_qutip/test_entropy.py +++ b/tests/test_qutip/test_entropy.py @@ -8,7 +8,7 @@ import qutip_jax qutip.settings.core["auto_real_casting"] = False -qutip_jax.use_jax_backend() +qutip_jax.set_as_default() tol = 1e-6 # Tolerance for assertion with qutip.CoreOptions(default_dtype="jax"): diff --git a/tests/test_qutip/test_mcsolve.py b/tests/test_qutip/test_mcsolve.py index 13fb374..29c06d3 100644 --- a/tests/test_qutip/test_mcsolve.py +++ b/tests/test_qutip/test_mcsolve.py @@ -7,7 +7,7 @@ from functools import partial # Use JAX backend for QuTiP -qjax.use_jax_backend() +qjax.set_as_default() # Define time-dependent functions @partial(jax.jit, static_argnames=("omega",))