Skip using a callback for default.qubit.jax and diff_method="parameter-shift" #3259

Closed
antalszava opened this issue Nov 3, 2022 · 5 comments · Fixed by #6788
Labels
enhancement ✨ New feature or request

Comments

@antalszava
Contributor

Feature details

Specifying diff_method="parameter-shift" and transforming the QNode with jax.jit results in applying the JAX JIT interface. When used with the "default.qubit.jax" device, computing the gradient involves using jax.pure_callback. Using jax.pure_callback, however, is not required with this device, because it is meant to be used natively with JAX data structures.
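A minimal sketch of the callback pattern described above, assuming a toy NumPy-only "device" that returns <Z> = cos(theta) after an RX(theta) rotation on |0>. The names host_expval, host_param_shift and expval are hypothetical and this is not PennyLane's actual JAX JIT interface; it only shows how a host-side forward pass and a host-side parameter-shift gradient can be wired into jax.jit through jax.pure_callback and jax.custom_jvp.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_expval(theta):
    # Hypothetical NumPy-only "device": <Z> after RX(theta) on |0> is cos(theta).
    return np.asarray(np.cos(theta), dtype=theta.dtype)


def host_param_shift(theta):
    # Two-term parameter-shift rule, evaluated on the host with NumPy.
    shift = np.pi / 2
    grad = (host_expval(theta + shift) - host_expval(theta - shift)) / 2
    return np.asarray(grad, dtype=theta.dtype)


@jax.custom_jvp
def expval(theta):
    theta = jnp.asarray(theta)
    spec = jax.ShapeDtypeStruct(theta.shape, theta.dtype)
    # Forward pass leaves the XLA computation and runs the NumPy device on the host.
    return jax.pure_callback(host_expval, spec, theta)


@expval.defjvp
def expval_jvp(primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    theta = jnp.asarray(theta)
    spec = jax.ShapeDtypeStruct(theta.shape, theta.dtype)
    # The gradient also needs a callback: the shifted evaluations run on the host.
    grad = jax.pure_callback(host_param_shift, spec, theta)
    return expval(theta), grad * theta_dot
```

Both jax.jit(expval) and jax.jit(jax.grad(expval)) work here, but every evaluation round-trips through Python on the host, which is the overhead this issue proposes to skip for a device that already understands JAX arrays.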

Implementation

With the current architecture of PennyLane, the simplest change would involve updating the JAX JIT interface to add logic specific to the default.qubit.jax device.
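For contrast, a sketch of what the callback-free path could look like, assuming the simulator itself is written with jax.numpy (the toy expval_z_after_rx below is hypothetical and is not default.qubit.jax). Because every operation is traceable, jax.jit compiles the shifted evaluations of the parameter-shift rule directly and no pure_callback is needed; jax.grad on the same function gives the same value and serves as a check.

```python
import jax
import jax.numpy as jnp


def expval_z_after_rx(theta):
    # Toy JAX-native simulator: RX(theta)|0> = cos(theta/2)|0> - i sin(theta/2)|1>.
    amp0 = jnp.cos(theta / 2)
    amp1 = -1j * jnp.sin(theta / 2)
    # <Z> = |amp0|^2 - |amp1|^2, which equals cos(theta).
    return jnp.abs(amp0) ** 2 - jnp.abs(amp1) ** 2


@jax.jit
def param_shift_grad(theta):
    # Two-term parameter-shift rule, traced and compiled end to end by XLA.
    shift = jnp.pi / 2
    return (expval_z_after_rx(theta + shift) - expval_z_after_rx(theta - shift)) / 2
```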

How important would you say this feature is?

1: Not important. Would be nice to have.

Additional information

No response

@astralcai
Contributor

default.qubit.jax does not exist anymore

@josh146
Member

josh146 commented Dec 12, 2024

Does the same issue still occur in the new implementation of JAX default.qubit?

@albi3ro
Contributor

albi3ro commented Dec 12, 2024

But this is still a problem that we haven't really fixed. Right now, we only send interface data to the device when diff_method="backprop". Otherwise, we assume the device requires pure numpy.

This was something that came up when porting to the new device interface: who is responsible for converting the parameters to numpy, the device or the core PennyLane workflow? Our current fix is to always convert things to numpy unless we want to do backprop.
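Why converting to numpy forces a callback under jax.jit can be seen in a small sketch (the helper names are hypothetical): np.asarray works on a concrete JAX array, but inside a jitted function the parameters are abstract tracers without values, so JAX raises jax.errors.TracerArrayConversionError.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_numpy(params):
    # Eager conversion: fine for concrete JAX arrays.
    return np.asarray(params)


@jax.jit
def convert_inside_jit(params):
    # Under jit, `params` is a tracer, so NumPy cannot materialize it.
    return jnp.sum(to_numpy(params))


def try_convert_inside_jit(params):
    # Returns True if the tracer-to-NumPy conversion error was raised.
    try:
        convert_inside_jit(params)
    except jax.errors.TracerArrayConversionError:
        return True
    return False
```

So whenever the device path expects numpy and the execution is being traced, the conversion has to happen outside the trace, for example inside a jax.pure_callback.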

But we could definitely use a better solution, especially if we want to jit more of the simulation or keep data on the GPU end-to-end.

We do have a hack where we set interface=None on the QNode and then pass in interface parameters to avoid the pure callback and numpy conversion. But I haven't been particularly happy with that solution.

This is also going to be a concern for the new program capture workflow. Who is responsible for applying the pure_callback?

@astralcai
Contributor

The original premise of the issue was that we had a device that "is meant to natively used with JAX data structures", which we don't anymore? Maybe we can rephrase the issue to something that makes sense in the current context?

@albi3ro
Contributor

albi3ro commented Dec 13, 2024

We have a device that can be natively used with all interface data structures.

willjmax pushed a commit that referenced this issue Feb 4, 2025
…ck` (#6788)

**Context:**

While we have logic for sampling with jax, it does not really integrate
very well into the workflow. You can technically set `diff_method=None`
right now and jit the execution end-to-end, but doing so causes
incomprehensible error messages on non-DQ devices.

We want to *forbid* differentiation with `diff_method=None`, but keep a
way to jit a finite-shot execution.

**Description of the Change:**

In order to allow jitting finite shot executions, we need a way for the
device to be able to configure whether or not the data is converted to
numpy. To do so, we simply add another property to the
`ExecutionConfig`, `convert_to_numpy`. If `False`, then we will not use
a `pure_callback` to convert the parameters to numpy. If `True`, we use
a `pure_callback` and convert the parameters to numpy.
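A minimal sketch of how such a flag could drive the dispatch, assuming a hypothetical frozen dataclass SketchConfig in place of PennyLane's real ExecutionConfig and a toy simulate function that takes either numpy or jax.numpy as its array module. The actual change in #6788 may be structured differently; the sketch only shows that the convert_to_numpy=False branch is traced directly by jax.jit, while the True branch goes through jax.pure_callback.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import numpy as np


@dataclass(frozen=True)
class SketchConfig:
    # Hypothetical stand-in for ExecutionConfig; only the new flag is modelled.
    convert_to_numpy: bool = True


def simulate(params, xp):
    # Toy device: <Z> after RX(theta) on |0> is cos(theta); `xp` is numpy or jax.numpy.
    return xp.cos(params)


def execute(params, config):
    if config.convert_to_numpy:
        # Device wants NumPy: leave the trace and run the device on the host.
        spec = jax.ShapeDtypeStruct(params.shape, params.dtype)

        def host_call(p):
            return np.asarray(simulate(np.asarray(p), np), dtype=p.dtype)

        return jax.pure_callback(host_call, spec, params)
    # Device accepts JAX data: no callback, so XLA can compile the whole execution.
    return simulate(params, jnp)


# The frozen dataclass is hashable, so it can be passed as a static argument.
jit_execute = jax.jit(execute, static_argnames="config")
```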

**Benefits:**

Speedups from being able to jit the entire execution.


![image](https://github.com/user-attachments/assets/738076c6-7bb5-4c38-a8cc-97e138325dbc)


**Possible Drawbacks:**

`ExecutionConfig` gets an additional property, making it more
complicated.

**Related GitHub Issues:**

Fixes #6054 Fixes #3259  Blocks #6770

---------

Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>