JIT-ed calculation of Hessian [grad(grad)] fails with JAX #2163
Comments
Hi @quantshah

This very much comes down to how the JAX JIT interface uses `host_callback`.
Thanks @antalszava for the explanation. Feel free to leave this issue open till there is a resolution, or close it, since this is an issue with JAX and not PL.
Sure :) I might leave it open, just so that there's a point of reference if this becomes a question for others too.
Note: this issue could potentially be resolved by a refactor of the JAX JIT interface. We have this on our radar and would like to look into a resolution in the coming weeks. Specifically, at the moment the interface relies on `host_callback.call`, which does not support JVPs.
@antalszava just curious: if we move towards having the quantum device itself compute the VJP, would this require that the cotangent vector be available at the time of the forward execution?
Likely not. JAX seems to assume that the cotangent vector only enters the backward pass, so it does not need to be available during the forward execution. Specifically, for the following toy example:

```python
import jax
import jax.numpy as jnp

params = jnp.array([0.1, 0.2])


@jax.custom_vjp
def wrapped_exec(params):
    y = params ** 2, params ** 3
    # don't need to compute jacs here
    return y


def wrapped_exec_fwd(params):
    y = wrapped_exec(params)
    jacs = jnp.diag(2 * params), jnp.diag(3 * params ** 2)  # compute here
    return y, jacs  # don't need params here


def wrapped_exec_bwd(res, g):
    jac1, jac2 = res
    g1, g2 = g
    return ((g1 @ jac1) + (g2 @ jac2),)


wrapped_exec.defvjp(wrapped_exec_fwd, wrapped_exec_bwd)

jax.jacobian(wrapped_exec)(params)
```
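For completeness (an added check, not part of the original comment): the same toy example also composes with `jax.jit`, since both the forward and backward passes stay entirely inside JAX.

```python
# Assumes the definitions from the snippet above (params, wrapped_exec).
# First-order reverse-mode differentiation of a custom_vjp function
# works under jit because nothing here leaves the XLA graph.
print(jax.jit(jax.jacobian(wrapped_exec))(params))
```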
Hi everyone, getting back to this thread as I saw that in JAX there is a possibility to implement higher-order gradients (VJPs) with `host_callback`. I had a look at the implementation here: https://github.com/google/jax/blob/main/tests/host_callback_to_tf_test.py#L100 but haven't figured out completely what is happening in the custom backward pass that allows one to compute higher-order gradients with `host_callback`. Just putting this out here for reference in the future in case we look into this again and it is helpful.
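For future reference, here is a hedged sketch (my own reading of that test, not PennyLane code) of the pattern used in `host_callback_to_tf_test.py`: every host-side function, including the ones used inside the backward pass, is itself wrapped in `jax.custom_vjp`, so reverse-mode autodiff can be applied repeatedly. The names (`host_fn_with_vjp`, `sin`, `cos`, `neg_sin`) and the NumPy stand-ins are illustrative, and `jax.experimental.host_callback` is assumed to be available (it has since been deprecated in newer JAX releases).

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import host_callback


def host_fn_with_vjp(fn, grad_fn):
    """Wrap a host-side `fn` so jax.grad works; `grad_fn` evaluates its derivative."""

    @jax.custom_vjp
    def wrapped(x):
        # The host-side function runs outside the XLA graph.
        return host_callback.call(fn, x, result_shape=x)

    def fwd(x):
        return wrapped(x), x  # keep the primal input as the residual

    def bwd(x, ct):
        # The derivative is evaluated through another wrapped host call,
        # so this backward pass is itself reverse-mode differentiable,
        # which is what buys one more order of differentiation.
        return (ct * grad_fn(x),)

    wrapped.defvjp(fwd, bwd)
    return wrapped


# Host-side NumPy functions and their derivatives, wrapped bottom-up.
neg_sin = host_fn_with_vjp(lambda a: -np.sin(a), lambda x: -jnp.cos(x))
cos = host_fn_with_vjp(np.cos, neg_sin)
sin = host_fn_with_vjp(np.sin, cos)

x = jnp.array(0.3)
print(jax.grad(sin)(x))            # ~ cos(0.3)
print(jax.grad(jax.grad(sin))(x))  # ~ -sin(0.3)
```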
I just stumbled on this issue. I think a way to avoid it would be to define a new JAX primitive operation. The main 'complication' is that I'm not sure if you can feed a … The interface is quite stable and has not changed in the last 2 years.

PS: if there's any interest in that, I don't have the time to put in, but I can provide some guidance. I have already done that for two different packages.
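For anyone picking this up later, here is a minimal sketch of what "define a new JAX primitive" could look like. The `qexec` names are made up for illustration, and the classical `jnp.sum(jnp.sin(...))` stand-in is only there so the snippet runs without a quantum device.

```python
import jax
import jax.numpy as jnp
from jax import core

# Hypothetical primitive standing in for "execute the quantum tape".
qexec_p = core.Primitive("qexec")


def qexec(params):
    """User-facing function that binds the primitive."""
    return qexec_p.bind(params)


def qexec_impl(params):
    # Concrete (non-traced) evaluation: this is where the quantum device
    # would actually be invoked; a classical function stands in here.
    return jnp.sum(jnp.sin(params))


def qexec_abstract_eval(params_aval):
    # Shape/dtype of the output without running the device; this is what
    # lets the primitive appear inside traced/jitted code.
    return core.ShapedArray((), params_aval.dtype)


qexec_p.def_impl(qexec_impl)
qexec_p.def_abstract_eval(qexec_abstract_eval)

print(qexec(jnp.array([0.1, 0.2])))  # eager evaluation goes through def_impl
```

Differentiation and batching rules, plus a compilation lowering, would then be registered against `qexec_p` separately.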
This is interesting, there is a nice example here of how this is done all the way up to JITing: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html

But the problem still remains that we cannot use XLA operations to do the expval evaluation, as the expval evaluation happens purely on the quantum device, right? So there has to be a break in the computational graph somewhere (which is what `host_callback` provides right now). Unless there is a way to use the XLA custom call (https://www.tensorflow.org/xla/custom_call) to get the value of the expval.
Hi @PhilipVinc, thank you for the suggestion! 🎉 Personally, I'm new to creating a custom JAX primitive, so any help and guidance here would definitely be appreciated. 🙂 Also wondering about the question that Shahnawaz mentioned: how could introducing the new primitive help with the specific error originally reported? It would seem that the issue is specific to the invocation of `host_callback.call`.
There are two graphs at play here: the one used during function transformations, pre-compilation, and the one XLA builds when it compiles. The issue with `host_callback.call` is that it has no JVP rule for the former. Then, you must also tell JAX what XLA operation the primitive corresponds to when it compiles (`jit`). You can always call ANY C code, but, I guess, you could also enqueue a host_callback operation.
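Continuing the hypothetical `qexec_p` sketch from above (same caveats apply): registering a lowering is what tells the compiler which operations the primitive turns into under `jax.jit`. Here the lowering simply reuses the classical stand-in via `mlir.lower_fun`; a real device execution would instead need an XLA custom call or a host callback at this point.

```python
from jax.interpreters import mlir

# Lower the primitive by tracing the stand-in implementation into
# standard JAX/StableHLO ops, so jax.jit(qexec) compiles.
mlir.register_lowering(
    qexec_p,
    mlir.lower_fun(qexec_impl, multiple_results=False),
)

print(jax.jit(qexec)(jnp.array([0.1, 0.2])))
```

Differentiation would still need its own rules (a JVP/transpose pair or a custom VJP) registered against the primitive.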
What issue exactly? I'm not familiar with the depth of PennyLane's source, so if you have a short example, even in pseudocode, that would help clarify.
Sure. 🙂 At the moment, the use of `host_callback.call` looks like this:

```python
args = tuple(params) + (g,)

vjps = host_callback.call(
    non_diff_wrapper,
    args,
    result_shape=jax.ShapeDtypeStruct((total_params,), dtype),
)
```

Passing the cotangent vector `g` as one of the `args` is how the VJP is evaluated on the host. At the same time, passing anything built on `host_callback.call` through forward-mode autodiff raises

```
NotImplementedError: JVP rule is implemented only for id_tap, not for call.
```

It would seem that there may be two components to a solution here: …
With those changes, we should have no `NotImplementedError`.
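As a side note (added for context, not from the thread), the second error can be reproduced in isolation; the snippet below assumes `jax.experimental.host_callback` is available and has nothing PennyLane-specific in it.

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback


def f(x):
    # The callback runs on the host; host_callback.call has no JVP rule.
    return host_callback.call(
        lambda a: a ** 2,
        x,
        result_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )


x = jnp.array([0.1, 0.2])
print(jax.jit(f)(x))  # forward execution is fine

# Any forward-mode transform, e.g. jax.jacfwd(f)(x) or jax.hessian of a
# function built on f, raises:
#   NotImplementedError: JVP rule is implemented only for id_tap, not for call.
```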
Are you aware of …?
Wasn't aware of it! 😲 Will try this out, thank you. 🙂 I see it's functionality that is still in the works, but it should be worthwhile to try because of the …
Expected behavior
I was trying to compute the Hessian and saw that the JAX interface breaks down if we have the JIT on. Without JIT, it works fine. The error seems to be due to the non-availability of JVPs in the `host_callback` bridge between PL and JAX. To make it work, just remove the `@jax.jit` from the definition of the circuit.

@josh146 and I discussed this over Slack, and it seems a bit strange that something that works with the JIT off fails when we simply JIT things.
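Since the original source code is not reproduced above, here is an assumed minimal example of the kind of circuit being described (illustrative only; the exact circuit, device, and diff method from the report may differ):

```python
import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit", wires=1)


@jax.jit  # removing this line makes the Hessian computation work
@qml.qnode(dev, interface="jax", diff_method="parameter-shift")
def circuit(x):
    qml.RX(x[0], wires=0)
    qml.RY(x[1], wires=0)
    return qml.expval(qml.PauliZ(0))


x = jnp.array([0.1, 0.2])
print(jax.grad(circuit)(x))     # first-order gradient: fine
print(jax.hessian(circuit)(x))  # grad-of-grad / Hessian: fails when jitted
```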
Actual behavior
```
can't apply forward-mode autodiff (jvp) to a custom_vjp function.

JVP rule is implemented only for id_tap, not for call.
```
Additional information
No response
Source code
Tracebacks
System information
Existing GitHub issues