Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to jax.pure_callback in the JAX Jit interface #3244

Merged
merged 15 commits into from
Nov 3, 2022

Conversation

antalszava
Copy link
Contributor

@antalszava antalszava commented Nov 1, 2022

Context:
jax.pure_callback was released that supports jax.vmap and jax.jacobian. In functionality, it seems to be very similar to host_callback.call used by PennyLane.

Description of the Change:
Changes to using jax.pure_callback in the JAX JIT interface.

Benefits:
jax.vmap and jax.jacobian can be used with multiple expvals and a single use of qml.probs.

Note: further support is not added here because of the planned changes coming up with the new return types system.

Possible Drawbacks:
jax.pure_callback requires JAX version 0.3.17.

Related GitHub Issues:
Closes #3218

@codecov
Copy link

codecov bot commented Nov 1, 2022

Codecov Report

❗ No coverage uploaded for pull request base (master@55e0ffc). Click here to learn what that means.
The diff coverage is 100.00%.

@@            Coverage Diff            @@
##             master    #3244   +/-   ##
=========================================
  Coverage          ?   99.71%           
=========================================
  Files             ?      288           
  Lines             ?    25705           
  Branches          ?        0           
=========================================
  Hits              ?    25631           
  Misses            ?       74           
  Partials          ?        0           
Impacted Files Coverage Δ
pennylane/interfaces/jax_jit.py 100.00% <100.00%> (ø)

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@PhilipVinc
Copy link
Contributor

FYI jax.pure_callback has been added in a relatively recent (last 2-3 months ago, IIRC) Jax release (v0.3.17).
I don't know how serious you are with error messages, but as the python package ecosystem is in a terrible mess (and you don't declare a dependency bound on Jax), you may want to avoid users raising issues you could throw an informative error on older jax versions saying something like "Pennylane-jax integration requires at least jax 0.3.17. Update jax and jaxlib by running pip install --upgrade jax jaxlib or equivalent depending on your package manager".

I'd recommend suggesting updating both jax and jaxlib because pip will not bump jaxlib if you don't ask him to, and updating only jax might break the environment...

@antalszava
Copy link
Contributor Author

We would like to test this with a remote device (e.g., IBMQ) and a GPU device.

@PhilipVinc
Copy link
Contributor

I have a gpu at hand and can try right now if you give me instructions.

@antalszava
Copy link
Contributor Author

Oh @PhilipVinc, thank you! Just saw your message, by this time we've also conducted these tests. 👍 Things worked well, so this PR will go through the review process.

As for general usage: we have the PennyLane-Lightning-GPU package that allows interfacing with Nvidia's cuQuantum library:
https://github.com/PennyLaneAI/pennylane-lightning-gpu

Doing

pip install pennylane-lightning[gpu] cuquantum

should be the simple way to installation.

@antalszava antalszava requested a review from rmoyard November 1, 2022 18:23
@antalszava antalszava changed the title jax.pure_callback Switch to jax.pure_callback Nov 1, 2022
@antalszava antalszava marked this pull request as ready for review November 1, 2022 21:52
@antalszava antalszava changed the title Switch to jax.pure_callback Switch to jax.pure_callback in the JAX Jit interface Nov 1, 2022
@antalszava
Copy link
Contributor Author

This PR was tested on GPUs and using remote device execution.

Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good, thanks @antalszava !

@Jaybsoni Jaybsoni added this to the v0.27.0 milestone Nov 2, 2022
Copy link
Contributor

@eddddddy eddddddy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good @antalszava 💯! Just one question.

pennylane/interfaces/jax_jit.py Show resolved Hide resolved
tests/interfaces/test_jax_qnode.py Outdated Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] jax.jit(jax.grad()) of a circuit with shots crashes
5 participants