Cross-referencing #25679, which I think is related.
For complex inputs, abs(x) is defined as sqrt(real(x)^2 + imag(x)^2), so its JVP rule is d(abs(x)) = (real(x) * real(dx) + imag(x) * imag(dx)) / abs(x). For this example (real(x) = inf, imag(x) = 0, unit tangent), that evaluates to inf/inf + 0/inf = nan. We could add special handling of infinities here, but it's not clear to me what the semantics should be (especially for imag(x) != 0). What do you think?
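To make the arithmetic concrete, here is a minimal pure-Python sketch of that JVP formula. It is not JAX's actual implementation; `abs_jvp` is a hypothetical helper, and returning a zero tangent at x = 0 is an assumption made only so the sketch does not divide by zero. At a finite point the formula behaves as expected, while at x = inf + 0j the real-part term becomes inf/inf, which is nan.

```python
import math


def abs_jvp(x: complex, t: complex) -> tuple[float, float]:
    """Hypothetical helper: primal and tangent of |x| for a tangent t.

    Uses d|x| = (Re(x) * Re(t) + Im(x) * Im(t)) / |x|.
    """
    r = abs(x)  # sqrt(real(x)^2 + imag(x)^2)
    if r == 0:
        # Assumption: pick 0 at the non-differentiable point x = 0.
        return r, 0.0
    tangent = (x.real * t.real + x.imag * t.imag) / r
    return r, tangent


# Finite point: (3 * 1 + 4 * 0) / 5 = 0.6
print(abs_jvp(3 + 4j, 1 + 0j))  # (5.0, 0.6)

# Infinite point: (inf * 1 + 0 * 0) / inf = inf / inf = nan
print(abs_jvp(complex(math.inf, 0.0), 1 + 0j))  # (inf, nan)
```

The nan comes entirely from the real-part term. Any special case for infinities would have to decide what inf/inf should mean here, and that choice is less obvious when imag(x) is also nonzero or infinite.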
Description
I am getting an unexpected nan gradient with the following example.
What jax/jaxlib version are you using?
jax v0.4.38, jaxlib v0.4.38, numpy v1.26.4
Which accelerator(s) are you using?
CPU
Additional system info?
Python 3.11, Windows
NVIDIA GPU info
No response