Skip to content

Commit

Permalink
Adding jax.Array to the JAX types in dispatch (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli authored Dec 13, 2022
1 parent 44a04bd commit 9c64ad1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
8 changes: 8 additions & 0 deletions qiskit_dynamics/dispatch/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@

JAX_TYPES = (DeviceArray, Tracer, JaxprTracer, JVPTracer)

try:
# This class was introduced in 0.4.0
from jax import Array

JAX_TYPES += (Array,)
except ImportError:
pass

try:
# This class is not in older versions of Jax
from jax.interpreters.partial_eval import DynamicJaxprTracer
Expand Down
4 changes: 4 additions & 0 deletions releasenotes/notes/jax-4-compatibility-8e3398e95f758dfe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
upgrade:
- |
The ``jax.Array`` class has been added to the dispatcher for compatibility with JAX 0.4.0.

0 comments on commit 9c64ad1

Please sign in to comment.