-
Notifications
You must be signed in to change notification settings - Fork 617
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Switch to jax.pure_callback
in the JAX Jit interface
#3244
Conversation
Codecov Report
@@ Coverage Diff @@
## master #3244 +/- ##
=========================================
Coverage ? 99.71%
=========================================
Files ? 288
Lines ? 25705
Branches ? 0
=========================================
Hits ? 25631
Misses ? 74
Partials ? 0
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
FYI I'd recommend suggesting updating both jax and jaxlib because pip will not bump jaxlib if you don't ask him to, and updating only jax might break the environment... |
We would like to test this with a remote device (e.g., IBMQ) and a GPU device. |
I have a gpu at hand and can try right now if you give me instructions. |
Oh @PhilipVinc, thank you! Just saw your message, by this time we've also conducted these tests. 👍 Things worked well, so this PR will go through the review process. As for general usage: we have the PennyLane-Lightning-GPU package that allows interfacing with Nvidia's cuQuantum library: Doing pip install pennylane-lightning[gpu] cuquantum should be the simple way to installation. |
jax.pure_callback
jax.pure_callback
in the JAX Jit interface
…into jax_pure_callback
This PR was tested on GPUs and using remote device execution. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks good, thanks @antalszava !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good @antalszava 💯! Just one question.
Context:
jax.pure_callback
was released that supportsjax.vmap
andjax.jacobian
. In functionality, it seems to be very similar tohost_callback.call
used by PennyLane.Description of the Change:
Changes to using
jax.pure_callback
in the JAX JIT interface.Benefits:
jax.vmap
andjax.jacobian
can be used with multiple expvals and a single use ofqml.probs
.Note: further support is not added here because of the planned changes coming up with the new return types system.
Possible Drawbacks:
jax.pure_callback
requires JAX version 0.3.17.Related GitHub Issues:
Closes #3218