[BUG] jax.jit(jax.grad()) of a circuit with shots crashes #3218
Comments
Hi @PhilipVinc, thank you for opening this issue. Can you please confirm whether you get this issue with PennyLane v0.26.0 too? Or only with the dev version?
Hi @CatalinaAlbornoz. I just tried and the issue also exists with 0.26.0. The diagnosis is the same: the host_callback should be fed the prng key as an argument.
Hi @PhilipVinc, thank you for the report and comments! 🙂 We'll be looking into this and come back with our findings shortly.
Thank you! Actually, if you would like any opinion or want to discuss some of those Jax-related mysteries more interactively on a call, feel free to drop me an email.
Just went through the description and tried the example myself. Agree on the points: as the key is not passed in, a tracer leaks when jitting, which leads to a leaked-tracer error. This is definitely a byproduct of the design done in PennyLane and likely requires a major design change, because the pipeline described originally is a Device API that is used by (almost) all devices in our ecosystem. Having said that, I'll continue the investigation and try to come up with a solution that could benefit the use case.

On the side, one question I have is: why would we like the PRNGKey to be passed into the function itself? Could the following (executable) solution still work, where we pass the prng_key in when creating the device and mark the device as a static argument?

```python
from jax.config import config

config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml

phys_qubits = 2
pars_q = np.random.rand(3)

def minimal_circ(params, dev):
    @qml.qnode(dev, interface="jax", diff_method="parameter-shift")
    def _measure_operator():
        qml.RY(params[0], wires=0)
        qml.RY(params[1], wires=1)
        op = qml.Hamiltonian([1.0], [qml.PauliZ(0) @ qml.PauliZ(1)])
        return qml.expval(op)

    res = _measure_operator()
    return res

grad_fun = jax.grad(minimal_circ)

prng_key = jax.random.PRNGKey(0)
dev = qml.device("default.qubit.jax", wires=tuple(range(phys_qubits)), shots=1000, prng_key=prng_key)

jax.jit(grad_fun, static_argnums=[1])(pars_q, dev)
```
Appreciate this a lot, definitely keen to hear more about what we could improve! 🙏

To answer the question in the brackets: though the callback calls Python code, it encapsulates the quantum device execution, which may happen on a remote simulator or a remote QPU. That's the main motivation behind us using host_callback.
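For readers unfamiliar with the mechanism being discussed, here is a minimal, PennyLane-free sketch of how host_callback.call hands control back to ordinary Python from inside a jitted function; the run_on_host function is just a made-up stand-in for the device execution, not PennyLane's actual pipeline.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import host_callback as hcb

def run_on_host(x):
    # Ordinary Python/NumPy; in PennyLane this is roughly where the
    # (possibly remote) device execution would happen.
    return np.sin(x)

@jax.jit
def f(x):
    # The output shape/dtype must be declared so XLA can stage the call.
    return hcb.call(run_on_host, x, result_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))

print(f(jnp.arange(4.0)))
```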
Yes, what you propose would work, in principle. However, Jax retraces/recompiles every time you change a static argument (here, the device carrying a new prng_key). In my use case, where I have a hybrid structure coupling a neural network and a quantum circuit, re-compiling leads to very, very large increases in computational time (at least when not using shots). As a side note, to make this work, you'd need to correctly compute the
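To make the recompilation point concrete, here is a small generic JAX sketch (not tied to PennyLane): whenever a static argument takes a new value, jit has to retrace and recompile, which is exactly what happens if the device holding a fresh prng_key is marked static.

```python
import jax
import jax.numpy as jnp

def f(x, scale):
    # This print fires only while JAX is tracing, i.e. on every (re)compilation.
    print("tracing with scale =", scale)
    return scale * jnp.sum(x ** 2)

f_jit = jax.jit(f, static_argnums=1)

x = jnp.arange(3.0)
f_jit(x, 2.0)  # traces and compiles
f_jit(x, 2.0)  # cache hit: no retracing
f_jit(x, 3.0)  # new static value -> retrace and recompile
```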
Yeah, this definitely makes sense. Though for the particular case of a local
I'm not sure I understand what is limiting here, probably also because I fail to see exactly where the device is captured. My very uninformed understanding is that you are taking the object-oriented/Pythonic approach of passing around the method of an object, which implicitly captures (in a somewhat opaque manner) the underlying instance. The standard way to do this in functional programming would be to split the functions from the data structure, so that you are obliged, in a sense, to pass the data structure as an argument. Jax likes that because it can do its tracer magic on the arguments.
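A generic (non-PennyLane) illustration of the contrast being described, where the Device class is made up for the example: state held on an object is baked into any trace as a constant, while state passed explicitly as an argument can be traced and transformed like any other input.

```python
import jax
import jax.numpy as jnp

class Device:
    """Made-up object-oriented style: the key lives on the instance."""
    def __init__(self, key):
        self.key = key  # captured implicitly; baked into any trace as a constant

    def expval(self, params):
        noise = 1e-3 * jax.random.normal(self.key, ())
        return jnp.sum(jnp.cos(params)) + noise

def expval(params, key):
    """Functional style: the 'data structure' (here just the key) is an argument."""
    noise = 1e-3 * jax.random.normal(key, ())
    return jnp.sum(jnp.cos(params)) + noise

params, key = jnp.ones(3), jax.random.PRNGKey(0)
# Because the key is an explicit argument, JAX can trace it like any other input:
print(jax.jit(jax.grad(expval))(params, key))
```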
This is a great point! We can leverage this specifically for the case at hand.

As for the other comments: yes, the OOP and the more implicit pipeline are more disadvantageous in this specific case. While PennyLane does follow functional approaches, parts of it are definitely not purely functional. After some more local exploration, a fix should be doable and we'll be focusing on working towards having it in the code base as soon as we can. We have a release coming up soon. I'll be commenting on the progress as this work moves along. 👍
@PhilipVinc was there a specific reason for using the default.qubit.jax device? Switching to the C++-based lightning.qubit device, the following runs for me:

```python
from jax.config import config

config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml

phys_qubits = 2
pars_q = np.random.rand(3)

def minimal_circ(params):
    dev = qml.device("lightning.qubit", wires=tuple(range(phys_qubits)), shots=100)

    @qml.qnode(dev, interface="jax-jit", diff_method="parameter-shift", cache=None)
    def _measure_operator():
        qml.RY(params[0], wires=0)
        qml.RY(params[1], wires=1)
        op = qml.Hamiltonian([1.0], [qml.PauliZ(0) @ qml.PauliZ(1)])
        return qml.expval(op)

    res = _measure_operator()
    return res

grad_fun = jax.grad(minimal_circ)
fun = jax.jit(grad_fun)

for _ in range(5):
    print(fun(pars_q))
```

On top of this, it could be interesting to see how the originally suggested
Not really, no. I think I used it because at first I (erroneously) thought I could not mix and match different devices and
Probably worse. Jax is especially bad at using more than 1 or 2 cores on CPU (I think its BLAS implementation is particularly conservative before switching to multi-threading) and I wouldn't be surprised if any purpose-written C kernel could beat XLA (Jax's compiler) when applying gates...
Thanks for the snippet! I'll surely try this out. Just to understand... how will the RNG seed work in that case? Is it using some internal state that gets updated every time it calls back into Python/lightning?
The sampling (including the RNG seed generation) is completely encompassed in the function that is invoked by host_callback. The function passed to
@antalszava thanks a lot for the snippet, indeed it pushes us forward! To get to the bottom of what we'd need, here is a more complicated MWE that breaks down once we start playing with vmap.

```python
from jax.config import config

config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml

phys_qubits = 2
n_configs = 5
pars_q = np.random.rand(n_configs, 2)

def minimal_circ(params):
    dev = qml.device("lightning.qubit", wires=tuple(range(phys_qubits)), shots=100)

    @qml.qnode(dev, interface="jax-jit", diff_method="parameter-shift", cache=None)
    def _measure_operator():
        qml.RY(params[0], wires=0)
        qml.RY(params[1], wires=1)
        op = qml.Hamiltonian([1.0], [qml.PauliZ(0) @ qml.PauliZ(1)])
        return qml.expval(op)

    res = _measure_operator()
    return res

# this works
jax.jit(minimal_circ)(pars_q[0])
jax.jit(jax.grad(minimal_circ))(pars_q[0])

# Vmapping the circuit does not work
minimal_circ_batch = jax.vmap(minimal_circ)
minimal_circ_batch(pars_q)

# Getting the jacobian (aka, vmap of grad) does not work
jax.vmap(jax.grad(minimal_circ))(pars_q)
```

In short, we want to run the same circuit for several different parameters, and compute the gradient for each of them. At least for the forward pass, I'd expect that your lightning device should support batching (vmap, in jax parlance) and speed up the calculation. I'm unsure if that would also work for the backward pass...

Of course, host_callback does not support vmapping. In that case, if your underlying C code does support it, you could transition to jax.pure_callback.

A note: your lightning
Ok, I just noticed that you implicitly support parameter broadcasting: minimal_circ(pars_q) works and does what I wanted to do with jax.vmap(minimal_circ)(pars_q). So I'm even more sure that

About the gradient... it's going to be a bit more tricky.

Note: if someone wonders why I would want to use
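For readers who have not met parameter broadcasting before, here is a rough sketch of what it looks like; this uses default.qubit in analytic mode to keep things simple and is not the exact circuit from the MWE above.

```python
import numpy as np
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit(params):
    # A leading batch dimension in the gate parameters is broadcast over.
    qml.RY(params[..., 0], wires=0)
    qml.RY(params[..., 1], wires=1)
    return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

batch = np.random.rand(5, 2)
print(circuit(batch))  # five expectation values, one per row of `batch`
```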
Hi @PhilipVinc, the behaviour with The use of
@antalszava changing all host_callback.call to jax.pure_callback should make vmap work. Though performance will be sub-optimal, because it will be inserting a loop. But I think it's possible to make it work without switching to jvp...
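A generic sketch (not PennyLane code) of the behaviour being referred to, using the jax.pure_callback API as it existed around the time of this thread: with vectorized=False, vmap falls back to looping the callback over the batch, so it works but without any batched speed-up on the device side.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_fn(p):
    # Stand-in for the device execution, running as plain NumPy on the host.
    return np.cos(p).sum()

def circuit(p):
    return jax.pure_callback(
        host_fn,
        jax.ShapeDtypeStruct((), p.dtype),
        p,
        vectorized=False,  # under vmap, JAX calls `host_fn` once per batch element
    )

batch = jnp.arange(6.0).reshape(3, 2)
print(jax.vmap(circuit)(batch))  # three values, computed via a loop of callbacks
```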
You're right! 🎉 Tried it in #3244. We'll check how JAX's vectorized version could be wired in with PennyLane's parameter broadcasting. As for the JVPs: it's a change we've been contemplating anyway because it allows both
Can I open an issue to track properly vectorising the grad call when using the jax interface?
For sure! Just wanted to leave a comment here, mentioning that although this issue is being closed, we'd like to track the improvements we discussed.
@PhilipVinc could you also open an issue describing the improvements we could make to the parameter broadcasting UI?
Yes, it's on my to-do list. I have a deadline on Monday, so my ETA is 10 days to be able to phrase something decently. And thanks for implementing this; it's very helpful for us to be able to work with Jax without worrying too much about how to work around issues.
For sure! Sounds good. 👍 I've opened an issue to Skip using a callback for
Expected behavior
Jitting the gradient of a QNode with a device using shots, when setting the PRNGKey, leads to a crash. I would expect this to work.

Below there is a snippet that easily reproduces this issue on master. Do note that if you remove the jax.jit, the gradient works, but this is by accident.

I think I know what is causing the bug, but the explanation is a bit involved: I will first give you a TLDR, then I will show you exactly where the crash happens, then I will reason about what is happening there.
TLDR
The problem arises because you are storing a tracer in DefaultQubitJax._prng_key, but you are not correctly passing this prng key as an argument of the host callback in jax_jit.py:_execute. Conceptually, you should pass the prng key as an arg of the callback, like you do for the parameters.

Instead, the device, and therefore the _prng_key, is captured in a nested series of lambdas/functions called from the callback. Therefore, when the callback is executed, it encounters a tracer object for the prng key which is not substituted with concrete values, and crashes.

Observing where the crash happens
As I am not very familiar with the internals of PennyLane, and as this crash happens inside of a callback, preventing proper stack traces from being printed, I had to resort to a very primitive way of debugging.

I have added several print statements in the various functions of PennyLane. You can install my instrumented copy of PennyLane by running:

Using this copy, and running the snippet below, you will see the following messages printed:
The call chain at the point of the crash is the following:
This crash is happening after the compilation. Inside of _execute, when executing the host callback, we hit the execute_fn in the callback.

The execute_fn is a series of wrappers around QubitDevice.batch_execute, which then calls self.execute.

In QubitDevice.execute there is a branch that, if a finite number of shots is specified, calls generate_samples and then DefaultQubitJax.sample_basis_states.

In this function we use ._prng_key to execute some jax random functions. But, as I said before, this is all being executed inside of a callback, so there should be no tracers there! Instead, as the device was captured in some lambdas, the device has a tracer as a prng key, and this leads to a crash.
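The mechanism can be reproduced in isolation with a toy example (no PennyLane involved; make_execute and run_on_host are made-up names): a traced key captured in a closure leaks into the host callback, and touching it there fails, while the concrete parameters passed as the callback argument are fine.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import host_callback as hcb

def make_execute(key):
    # `key` is captured in the closure; under jit it is a tracer.
    def run_on_host(params):
        # `params` arrives as a concrete array, but `key` is still the tracer
        # from the (already finished) outer trace -> leaked-tracer error here.
        return np.asarray(params) * np.asarray(key)[0]

    def execute(params):
        return hcb.call(
            run_on_host, params,
            result_shape=jax.ShapeDtypeStruct(params.shape, params.dtype),
        )

    return execute

@jax.jit
def f(params, key):
    return jnp.sum(make_execute(key)(params))

f(jnp.ones(3), jax.random.PRNGKey(0))  # fails: the captured key escaped its trace
```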
Possible solution

The solution is to pass the prng key as an argument to the callback. In a sense, you'd need to do something similar to cp_tape for the prng key of the device.

However, this seems complicated to do because you are not passing the device itself as an argument to those functions; it is captured inside of lambdas (I think). But maybe someone who is more familiar with PennyLane (@antalszava?) might know how to do this?
Source code
System information
Existing GitHub issues