-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix _eigs_jaxarray to be compatible with jit #52
Conversation
@Ericgig I have made some changes, please have a look. As far as I tested this way we can work around |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This approach is looking good.
Have you confirmed that jax support grad
for eigen problems?
@Ericgig You are right I missed that |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now can eigs_jaxarray
itself be jit
if so let's move the jit from _eigs_jaxarray
to there.
We use jnp.complex128
for data. Not jnp.complex64
causing the tests to fail.
How about astype(data.dtype)
? This would be future proof if we decide to support both types.
Look good. |
@Ericgig I have added the suggested changes and the required tests are now passing. |
Yes, it works. @Ericgig |
Could you add/modify a test for it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you.
##Description
This PR is in an effort to enable jax.jit with qutip.core.metrics and entropy. It fixes _eigs_jaxarray to be compatible with jax.jit.
##Result
With this change trace_dist works with jax.jit