Solver.solve auto-jitting behaviour for pulse schedules breaks when JAX transformations are called from outside #175

Closed
DanPuzzuoli opened this issue Jan 6, 2023 · 2 comments

@DanPuzzuoli
Collaborator

Currently, if a Solver is configured to do pulse simulation, and everything is set up to use JAX (plus some other method-specific conditions), a call to Solver.solve with a pulse schedule will perform an auto-jitting procedure internally to speed up simulation.

With the next release of terra, and the subsequent merging of #149, it will be possible for calls to Solver.solve with pulse schedules to be contained in functions being jitted (or JAX-transformed in some other way).

@to24toro has observed that, when using the terra main branch and the branch of dynamics in #149:

  • If the Solver has been configured for pulse simulation, calling Solver.solve inside a function being transformed raises an error.
  • If, instead, the schedule is manually converted using InstructionToSignals and the results are passed to Solver.solve as Signals (within a function being transformed), everything works.

This needs to be resolved before merging #149. There are two options I can see here:

  • The simplest option is to include detection of whether a transformation is being performed in the decision of whether to do the internal auto-jitting. One way of doing this is explained in this JAX discussion. If a transformation is occurring, the regular code is executed; if not, the auto-jit version is executed. While fairly easy, this approach is actually warned against in the JAX discussion, as ideally code should not branch on whether a transformation is being performed. Doing so may violate "jit invariance", the requirement that a function's output be the same regardless of whether it has been jitted.
  • Inspect the logic of the auto-jitting and try to make it jittable as well. I suspect this may be difficult/convoluted (as the auto-jitting already is).

While I don't like the idea of having branching code that detects if jit is occurring, it will be much simpler to implement, and in the end maybe we can simplify the auto-jitting behaviour anyway once all forms of pulse schedule generation are transformable.
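The branching approach described in the first option can be sketched roughly as follows. This is a minimal illustration, assuming the transformation is detected by checking whether an argument is a `jax.core.Tracer`; that is one common technique, not necessarily the one in the linked JAX discussion or the one adopted in #149. `solve`, `_simulate`, `_is_traced`, and `not_jit_invariant` are hypothetical stand-ins, not qiskit-dynamics APIs.

```python
import jax
import jax.numpy as jnp


def _is_traced(x) -> bool:
    # Hypothetical helper: inside jit/grad/vmap, array arguments are replaced
    # by abstract tracers, so this is True only while a transformation is active.
    return isinstance(x, jax.core.Tracer)


def _simulate(t_final):
    # Hypothetical stand-in for the expensive simulation logic.
    return jnp.sin(t_final) ** 2


# The "auto-jitted" version, used only when solve is called eagerly.
_simulate_jit = jax.jit(_simulate)


def solve(t_final):
    if _is_traced(t_final):
        # Already inside an outer transformation: run the plain code and let
        # the outer transformation handle tracing and compilation.
        return _simulate(t_final)
    # Called eagerly: dispatch to the internally jitted version.
    return _simulate_jit(t_final)


# Both branches compute the same function, so jit invariance holds here.
eager = solve(1.0)
jitted = jax.jit(solve)(1.0)
grad = jax.grad(solve)(1.0)  # d/dt sin(t)^2 = sin(2t)


def not_jit_invariant(x):
    # Counter-example of the hazard: if the two branches compute different
    # things, the result depends on whether the caller used jit.
    if _is_traced(x):
        return x + 1.0
    return x
```

Here `not_jit_invariant(1.0)` returns `1.0` while `jax.jit(not_jit_invariant)(1.0)` returns `2.0`, which is the kind of discrepancy the JAX discussion warns about. The branching is only safe when both paths compute the same result, as `solve` does above.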

DanPuzzuoli added this to the pulse sim v1 milestone on Jan 6, 2023
@to24toro
Contributor

https://gist.github.com/to24toro/f7f9524fc4ea4f8853efb370bef33182 is example code reproducing the error raised when Solver.solve is called inside a function that is already being transformed by jax.jit.

@DanPuzzuoli
Collaborator Author

Closed by #149
