The gradients of segment_sum are incorrect when any segment_ids >= num_segments.
Reproducer:
import jax
import jax.numpy as jnp
import numpy as np
data = np.array([0, 0], dtype=np.float32)
num_segments = 2
segment_ids = np.array([2, 3])
def fn(data, segment_ids):
    return jax.ops.segment_sum(data, segment_ids, num_segments).sum()
value_and_grad_fn = jax.value_and_grad(fn)
val, grad = value_and_grad_fn(data, segment_ids)
print(val) # 0, Correct
print(grad) # [1., 1.], Incorrect, should be [0., 0.]
I suspect the reason is that the gradient of segment_sum is a gather/indexing operation, which, per the Sharp Bits, returns the last value for indices outside the size of the array (unlike TensorFlow, which returns zeros in those cases), and:
Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of undefined behavior.
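For what it's worth, that clamping behaviour lines up exactly with the [1., 1.] gradient above: the upstream cotangent of fn is all ones, and gathering it at the out-of-range ids reads the last element instead of zero. A minimal check of that reading (this is my mental model of the backward pass, not the actual implementation):

import jax.numpy as jnp
import numpy as np
cotangent = jnp.ones(2)         # d(sum)/d(segment_sum output), one entry per segment
oob_ids = np.array([2, 3])      # both out of range for num_segments = 2
print(cotangent[oob_ids])       # [1. 1.] -- the clamped gather reads the last entry, matching the bad grad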
However, the documentation of segment_sum does not at all dissuade the user from passing indices outside the range [0, num_segments); it simply says:
Values outside of the range [0, num_segments) are dropped and do not contribute to the sum.
So to fix this bug, I would recommend either updating the docs to strongly discourage passing indices outside the range, or fixing the gradients of segment_sum by combining the gather with a where, masking out any gradient propagation for rows whose indices fall outside the [0, num_segments) range (equivalent to using tf.gather-style logic of returning zeros for out-of-range indices in the backward pass of segment_sum).
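In the meantime, here is a rough sketch of the masking idea as a user-side workaround (masked_segment_sum is just a hypothetical helper name). Zeroing the out-of-range rows before the call leaves the forward result unchanged, and the where then blocks any gradient from flowing back into those rows, which has the same effect as masking the gather in the backward pass:

import jax
import jax.numpy as jnp
import numpy as np

def masked_segment_sum(data, segment_ids, num_segments):
    # Zero rows with out-of-range ids up front; segment_sum drops them anyway,
    # so the forward value is unchanged, but the where stops gradient flow into them.
    # (Written for 1-D data; higher-rank data would need in_range reshaped for broadcasting.)
    in_range = (segment_ids >= 0) & (segment_ids < num_segments)
    return jax.ops.segment_sum(jnp.where(in_range, data, 0.0), segment_ids, num_segments)

def fn2(data, segment_ids):
    return masked_segment_sum(data, segment_ids, 2).sum()

val, grad = jax.value_and_grad(fn2)(np.array([0, 0], dtype=np.float32), np.array([2, 3]))
print(val, grad)  # 0.0 [0. 0.]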
Thanks in advance!
(Happily, I'm already working slowly on cleaning up the out-of-bounds gather/scatter semantics, so fixing this issue is simply a question of changing a default mode.)
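To illustrate the mode distinction on the gather side (assuming a JAX version where .at[...].get accepts the mode and fill_value keywords), an explicit zero-fill mode gives exactly the TF-style behaviour discussed above:

import jax.numpy as jnp
import numpy as np

cotangent = jnp.ones(2)
oob_ids = np.array([2, 3])
print(cotangent[oob_ids])                                     # [1. 1.] -- default mode clamps out-of-bounds indices
print(cotangent.at[oob_ids].get(mode="fill", fill_value=0))   # [0. 0.] -- fill mode returns zeros instead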