Incorrect gradient for segment sum #8634

Closed
alvarosg opened this issue Nov 20, 2021 · 1 comment · Fixed by #8651
Labels
bug Something isn't working

Comments

@alvarosg
Contributor

The gradients of segment sum are incorrect when any segment_ids >= num_segments.
Reproducer:

```python
import jax
import jax.numpy as jnp
import numpy as np

data = np.array([0, 0], dtype=np.float32)
num_segments = 2
segment_ids = np.array([2, 3])

def fn(data, segment_ids):
  return jax.ops.segment_sum(data, segment_ids, num_segments).sum()

value_and_grad_fn = jax.value_and_grad(fn)
val, grad = value_and_grad_fn(data, segment_ids)
print(val)  # 0, Correct
print(grad)  # [1., 1.], Incorrect, should be [0., 0.]
```

I suspect the reason is that the gradient (the transpose) of segment_sum is a gather/indexing operation, which, as described in JAX's Sharp Bits, clamps out-of-bounds indices and therefore returns the last value for indices beyond the size of the array (unlike TensorFlow, whose tf.gather returns zeros in those cases on GPU). The Sharp Bits also note:

> Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of undefined behavior.
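The suspected mechanism can be illustrated with explicit indexing modes. This is a minimal sketch with arbitrary values, assuming a JAX version whose .at[...].get() accepts the mode and fill_value keywords: mode="clip" clamps an out-of-bounds index to the last element (the behavior described above), while mode="fill" with fill_value=0.0 returns zeros, which is closer to the TensorFlow-style behavior this report asks for.

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([1, 5])  # index 5 is out of bounds for a length-3 array

# Clamp out-of-bounds indices into range: index 5 reads x[2].
clamped = x.at[idx].get(mode="clip")

# Replace out-of-bounds reads with an explicit fill value instead.
filled = x.at[idx].get(mode="fill", fill_value=0.0)

print(clamped)  # values: 20.0, 30.0
print(filled)   # values: 20.0, 0.0
```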

However, the documentation of segment_sum does not discourage passing indices outside the range [0, num_segments) at all; it simply says:

> Values outside of the range [0, num_segments) are dropped and do not contribute to the sum.
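The documented forward behavior can be checked on its own. A small sketch with arbitrary values, assuming a JAX release in which out-of-range ids are dropped by default as the documentation states: the id 5 falls outside [0, 2), so the last element of data does not contribute.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 1, 1, 5])  # 5 is outside [0, num_segments)

# Segment 0 receives 1.0, segment 1 receives 2.0 + 3.0; the row with id 5 is dropped.
sums = jax.ops.segment_sum(data, segment_ids, num_segments=2)
print(sums)
```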

So, to fix this bug, I would recommend either updating the docs to strongly discourage passing indices outside the range, or fixing the gradient of segment_sum by combining the gather with a where, to mask out any gradient propagation for rows whose indices are outside the range [0, num_segments) (equivalent to using tf.gather-style logic that returns zeros when indexing out of range in the backward pass of segment_sum).
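The second option could look roughly like the following. This is a hypothetical user-side workaround, not the change that was made in JAX: masked_segment_sum is an illustrative name, and the backward pass is written by hand with jax.custom_vjp. The cotangent is gathered with clipped (always in-bounds) indices and then zeroed with jnp.where for every row whose id lies outside [0, num_segments), so dropped rows receive no gradient.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical helper: segment_sum whose gradient ignores out-of-range ids.
# num_segments is a static Python int, so it is marked non-differentiable.
@functools.partial(jax.custom_vjp, nondiff_argnums=(2,))
def masked_segment_sum(data, segment_ids, num_segments):
    return jax.ops.segment_sum(data, segment_ids, num_segments)


def _fwd(data, segment_ids, num_segments):
    out = jax.ops.segment_sum(data, segment_ids, num_segments)
    return out, segment_ids  # keep the ids as the residual for the backward pass


def _bwd(num_segments, segment_ids, g):
    # Rows whose id is in [0, num_segments) actually contributed to the sum.
    in_range = (segment_ids >= 0) & (segment_ids < num_segments)
    # Clip so the gather never goes out of bounds, then mask the clipped rows.
    safe_ids = jnp.clip(segment_ids, 0, num_segments - 1)
    gathered = g[safe_ids]
    mask = in_range.reshape(in_range.shape + (1,) * (gathered.ndim - 1))
    data_ct = jnp.where(mask, gathered, jnp.zeros_like(gathered))
    # None means "no cotangent" for the integer segment_ids argument.
    return data_ct, None


masked_segment_sum.defvjp(_fwd, _bwd)


def fn(data, segment_ids):
    return masked_segment_sum(data, segment_ids, 2).sum()


data = jnp.zeros(2, dtype=jnp.float32)
grad = jax.grad(fn)(data, jnp.array([2, 3]))
print(grad)  # both ids are out of range, so the gradient is zero
```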

Thanks in advance!

@alvarosg alvarosg added the bug Something isn't working label Nov 20, 2021
@alvarosg alvarosg changed the title from "Incorrect gather outside of segment sum" to "Incorrect gradient for segment sum" Nov 22, 2021
@hawkinsp hawkinsp self-assigned this Nov 22, 2021
copybara-service bot pushed a commit that referenced this issue Nov 22, 2021
…FILL_OR_DROP.

This matches the documented behavior.

Fixes #8634

PiperOrigin-RevId: 411617006
@hawkinsp
Collaborator

#8651 should fix.

(Happily, I'm already working slowly on cleaning up the out-of-bounds gather/scatter semantics, so fixing this issue is simply a question of changing a default mode.)
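Since the fix comes down to the scatter mode used inside segment_sum, its effect on the gradient can be seen by passing the mode explicitly. A sketch, assuming a JAX release in which jax.ops.segment_sum accepts a mode argument of type jax.lax.GatherScatterMode (current releases do): with FILL_OR_DROP the out-of-range rows are dropped in the forward pass and the transposed gather fills them with zero gradient, whereas with CLIP both ids are clamped into the last segment, so each row contributes and receives a gradient of 1.

```python
import jax
import jax.numpy as jnp
from jax.lax import GatherScatterMode

data = jnp.zeros(2, dtype=jnp.float32)
segment_ids = jnp.array([2, 3])  # both outside [0, num_segments)
num_segments = 2


def make_fn(mode):
    def fn(data):
        return jax.ops.segment_sum(data, segment_ids, num_segments, mode=mode).sum()
    return fn


# Out-of-range rows are dropped, so they get no gradient.
grad_drop = jax.grad(make_fn(GatherScatterMode.FILL_OR_DROP))(data)
# Out-of-range ids are clamped to the last segment, so every row contributes.
grad_clip = jax.grad(make_fn(GatherScatterMode.CLIP))(data)

print(grad_drop)
print(grad_clip)
```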
