Skip to content

Commit

Permalink
Merge pull request #51 from Ericgig/update.june
Browse files Browse the repository at this point in the history
Fix for updates in jax and qutip
  • Loading branch information
Ericgig authored Jun 6, 2024
2 parents b4ff4e1 + 2a3b5f1 commit 336f62f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: [3.9]
python-version: [3.11]
case-name: [defaults]

steps:
Expand Down
4 changes: 2 additions & 2 deletions src/qutip_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
(JaxArray, qutip.data.Dense, jaxarray_from_dense),
(qutip.data.Dense, JaxArray, dense_from_jaxarray, 2),
(JaxArray, JaxDia, jaxarray_from_jaxdia),
(JaxDia, JaxArray, jaxdia_from_jaxarray),
(qutip.data.Dia, JaxDia, dia_from_jaxdia),
(JaxDia, JaxArray, jaxdia_from_jaxarray, 1.2),
(qutip.data.Dia, JaxDia, dia_from_jaxdia, 2),
(JaxDia, qutip.data.Dia, jaxdia_from_dia),
]
)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def cte(t, A):
return A


# diffrax use clip with deprecated parameters
@pytest.mark.filterwarnings("ignore:Passing arguments 'a'")
@pytest.mark.parametrize("dtype", ("jax", "jaxdia"))
def test_ode_run(dtype):
with CoreOptions(default_dtype=dtype):
Expand All @@ -57,6 +59,7 @@ def test_ode_run(dtype):
np.testing.assert_allclose(result.expect[0], expected.expect[0], atol=1e-6)


@pytest.mark.filterwarnings("ignore:Passing arguments 'a'")
@pytest.mark.parametrize("dtype", ("jax", "jaxdia"))
def test_ode_step(dtype):
with CoreOptions(default_dtype=dtype):
Expand All @@ -79,6 +82,7 @@ def test_ode_step(dtype):
assert (solver.step(1) - ref_solver.step(1)).norm() <= 1e-6


@pytest.mark.filterwarnings("ignore:Passing arguments 'a'")
@pytest.mark.parametrize("dtype", ("jax", "jaxdia"))
def test_ode_grad(dtype):
with CoreOptions(default_dtype=dtype):
Expand Down

0 comments on commit 336f62f

Please sign in to comment.