
Unexpected NaN gradient of jnp.abs at ±inf + 0j #25681

Open
lfaucheux opened this issue Dec 24, 2024 · 3 comments
Assignees: dfm
Labels: bug (Something isn't working)


lfaucheux commented Dec 24, 2024

Description

I am getting an unexpected NaN gradient with the following example.

>>> import jax
>>> (g := jax.grad(jax.numpy.abs))(float('inf') + 0j)
Array(nan+nanj, dtype=complex128)
# >>> g(float('inf'))
# Array(1., dtype=float64, weak_type=True)

What jax/jaxlib version are you using?

jax v0.4.38, jaxlib v0.4.38, numpy v1.26.4

Which accelerator(s) are you using?

CPU

Additional system info?

Python 3.11, Windows

NVIDIA GPU info

No response

lfaucheux added the bug label on Dec 24, 2024
lfaucheux changed the title from "Unexpected NaN gradient of jnp.abs at inf + 0j" to "Unexpected NaN gradient of jnp.abs at ±inf + 0j" on Dec 24, 2024
dfm self-assigned this on Dec 29, 2024
dfm (Collaborator) commented Jan 6, 2025

Cross-referencing #25679, which is related, I think!

For complex inputs, abs(x) is defined as sqrt(real(x)^2 + imag(x)^2), so the JVP rule is d(abs(x)) = (real(x) * real(dx) + imag(x) * imag(dx)) / abs(x), which for this example evaluates to inf/inf + 0/inf = nan. It's possible that we could add special handling of infinities in this case, but it's not totally clear to me what the semantics should be (especially for imag(x) != 0). What do you think?
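
For concreteness, here is a minimal sketch in plain Python (not JAX's actual JVP code) of how that quotient goes to NaN when evaluated naively at z = inf + 0j with tangent 1 + 0j; the variable names are illustrative only:

import math

# Naive evaluation of the JVP of abs(z) = sqrt(re^2 + im^2)
# with tangent (d_re, d_im):  d_abs = (re*d_re + im*d_im) / abs(z)
re, im = math.inf, 0.0              # z = inf + 0j
d_re, d_im = 1.0, 0.0               # tangent 1 + 0j
r = math.hypot(re, im)              # abs(z) = inf
print((re * d_re + im * d_im) / r)  # (inf + 0) / inf -> nan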

jakevdp (Collaborator) commented Jan 6, 2025

This is also related to previous discussions of the grad of abs at complex zero. See e.g. #10515 (comment)

pearu (Collaborator) commented Jan 8, 2025

Analytically, we have

grad(abs)(x + y*1j) == sin(a) - cos(a)*1j
a = atan2(x, y)

Since atan2(x, y) is defined for any x, y, including infinities, the result of grad(abs)(x + y*1j) is well defined. In this particular case, we have

x = inf
y = 0
a = atan2(x, y) -> pi/2
grad(abs)(x + y*1j) = sin(a) - cos(a)*1j -> 1 - 0j
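
A quick NumPy sketch of this formula (keeping the argument order written above, atan2(x, y) rather than the usual atan2(y, x); abs_grad_via_atan2 is just an illustrative name, not a proposed API):

import numpy as np

def abs_grad_via_atan2(z):
    # Analytic form from above: sin(a) - cos(a)*1j with a = atan2(x, y),
    # i.e. (x - y*1j) / abs(z), but it stays finite for infinite x or y.
    x, y = np.real(z), np.imag(z)
    a = np.arctan2(x, y)
    return np.sin(a) - np.cos(a) * 1j

print(abs_grad_via_atan2(np.inf + 0j))  # ~ (1 - 0j), no nan
print(abs_grad_via_atan2(3.0 + 4.0j))   # ~ (0.6 - 0.8j), i.e. conj(z) / abs(z)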
