Skip using a callback for default.qubit.jax and diff_method="parameter-shift" #3259
Comments
Does the same issue still occur in the new implementation of JAX `default.qubit`?
But this is still a problem that we haven't really fixed. Right now, we only send interface data to the device when …

This came up when porting to the new device interface: who is responsible for converting the parameters to NumPy, the device or the core PennyLane workflow? Our current fix is to always convert things to NumPy unless we want to do backprop. We could definitely use a better solution, especially if we want to jit more of the simulation or keep data on the GPU end-to-end. We do have a hack where we tell the QNode …

This is also going to be a concern for the new program capture workflow. Who is responsible for applying the …
The original premise of the issue was that we had a device that "is meant to be natively used with JAX data structures", which we no longer have? Maybe we can rephrase the issue into something that makes sense in the current context?
We have a device that can be natively used with all interface data structures.
…ck` (#6788)

**Context:** While we have logic for sampling with JAX, it does not integrate very well into the workflow. You can technically set `diff_method=None` right now and jit the execution end-to-end, but jitting with `diff_method=None` causes incomprehensible error messages on non-DQ devices (devices other than `default.qubit`). We want to *forbid* differentiation with `diff_method=None`, but keep a way to jit a finite-shot execution.

**Description of the Change:** To allow jitting finite-shot executions, the device needs a way to configure whether or not the data is converted to NumPy. To do so, we add another property to the `ExecutionConfig`: `convert_to_numpy`. If `False`, we do not use a `pure_callback` to convert the parameters to NumPy. If `True`, we use a `pure_callback` and convert the parameters to NumPy.

**Benefits:** Speed-ups from being able to jit the entire execution.

![image](https://github.com/user-attachments/assets/738076c6-7bb5-4c38-a8cc-97e138325dbc)

**Possible Drawbacks:** `ExecutionConfig` gets an additional property, making it more complicated.

**Related GitHub Issues:** Fixes #6054, Fixes #3259, Blocks #6770

---------

Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
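The `convert_to_numpy` switch described in this change can be sketched in plain JAX. This is a hypothetical toy, not PennyLane's code: `ToyExecutionConfig`, `execute`, `numpy_simulate`, and `jax_simulate` are illustrative names for a one-qubit circuit (RX(theta) on |0>, measuring <Z> = cos(theta)). The only real library calls are `jax.pure_callback`, `jax.ShapeDtypeStruct`, and `jax.jit`. When the flag is `True`, the parameters leave the trace through a host callback. When it is `False`, the simulation stays in `jax.numpy` and the whole execution can be compiled.

```python
import dataclasses

import jax
import jax.numpy as jnp
import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyExecutionConfig:
    # Hypothetical stand-in for the `convert_to_numpy` option; not PennyLane's class.
    convert_to_numpy: bool = True


def numpy_simulate(theta):
    # Toy NumPy "device": <Z> after RX(theta) acting on |0> equals cos(theta).
    theta = np.asarray(theta)
    return np.asarray(np.cos(theta), dtype=theta.dtype)


def jax_simulate(theta):
    # The same toy simulation in jax.numpy, so JAX can trace and compile it.
    return jnp.cos(theta)


def execute(theta, config):
    theta = jnp.asarray(theta)
    if config.convert_to_numpy:
        # Leave the trace: JAX hands concrete NumPy arrays to the host function,
        # so the output shape and dtype must be declared in advance.
        spec = jax.ShapeDtypeStruct(theta.shape, theta.dtype)
        return jax.pure_callback(numpy_simulate, spec, theta)
    # Stay inside the trace: jax.jit can compile the entire execution.
    return jax_simulate(theta)


# The frozen (hashable) config is passed as a static argument, so each
# setting of the flag compiles to its own program.
jitted_execute = jax.jit(execute, static_argnums=1)
print(jitted_execute(0.3, ToyExecutionConfig(convert_to_numpy=True)))
print(jitted_execute(0.3, ToyExecutionConfig(convert_to_numpy=False)))
```

Both settings return the same value here. The difference is that the `False` branch gives XLA the full computation to optimize, while the `True` branch always round-trips through the host.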
Feature details
Specifying `diff_method="parameter-shift"` and transforming the QNode with `jax.jit` results in applying the JAX JIT interface. When used with the `"default.qubit.jax"` device, computing the gradient involves using `jax.pure_callback`. Using `jax.pure_callback`, however, is not required with this device, because it is meant to be natively used with JAX data structures.
Implementation
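The callback path described in this request can be illustrated with a small, hypothetical sketch (not PennyLane's actual JAX JIT interface). A NumPy-only device is reached through `jax.pure_callback`. Because JAX cannot differentiate through a host callback, the gradient is supplied by a `jax.custom_jvp` rule that applies the parameter-shift rule, itself calling the device through the callback. `numpy_device`, `run_on_device`, and `circuit` are illustrative names for a one-qubit RX(theta) circuit measuring <Z>.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_device(theta):
    # Hypothetical NumPy-only device: <Z> after RX(theta) acting on |0>.
    theta = np.asarray(theta)
    return np.asarray(np.cos(theta), dtype=theta.dtype)


def run_on_device(theta):
    # A jitted function can only reach NumPy code through a host callback,
    # which needs the output shape and dtype declared up front.
    theta = jnp.asarray(theta)
    spec = jax.ShapeDtypeStruct(theta.shape, theta.dtype)
    return jax.pure_callback(numpy_device, spec, theta)


@jax.custom_jvp
def circuit(theta):
    return run_on_device(theta)


@circuit.defjvp
def circuit_jvp(primals, tangents):
    (theta,), (dtheta,) = primals, tangents
    shift = np.pi / 2
    # JAX cannot differentiate through pure_callback, so the derivative comes
    # from the parameter-shift rule: two more device executions via the callback.
    shifted_diff = run_on_device(theta + shift) - run_on_device(theta - shift)
    derivative = shifted_diff / 2
    return run_on_device(theta), derivative * dtheta


print(jax.jit(circuit)(0.3))
print(jax.jit(jax.grad(circuit))(0.3))
```

A device written entirely in `jax.numpy` needs neither the callback nor the custom rule: for this toy circuit, `jax.grad(jnp.cos)` differentiates the simulation directly, which is why the callback is unnecessary for such a device.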
With the current architecture of PennyLane, the simplest change would be to update the JAX JIT interface with logic specific to the `default.qubit.jax` device.
How important would you say this feature is?
1: Not important. Would be nice to have.
Additional information
No response