Skip to content

Commit

Permalink
Change the default out-of-bounds behavior for jax.ops.segment_... to …
Browse files Browse the repository at this point in the history
…FILL_OR_DROP.

This matches the documented behavior.

Fixes #8634

PiperOrigin-RevId: 411617006
  • Loading branch information
hawkinsp authored and jax authors committed Nov 22, 2021
1 parent ad6ce74 commit 2bf23d2
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 16 deletions.
53 changes: 37 additions & 16 deletions jax/_src/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,10 @@ def _segment_update(name: str,
indices_are_sorted: bool = False,
unique_indices: bool = False,
bucket_size: Optional[int] = None,
reducer: Optional[Callable] = None) -> Array:
reducer: Optional[Callable] = None,
mode: Optional[lax.GatherScatterMode] = None) -> Array:
jnp._check_arraylike(name, data, segment_ids)
mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
data = jnp.asarray(data)
segment_ids = jnp.asarray(segment_ids)
dtype = data.dtype
Expand All @@ -430,7 +432,7 @@ def _segment_update(name: str,
if num_buckets == 1:
return _scatter_update(
out, segment_ids, data, scatter_op, indices_are_sorted,
unique_indices, normalize_indices=False)
unique_indices, normalize_indices=False, mode=mode)

# Bucketize indices and perform segment_update on each bucket to improve
# numerical stability for operations like product and sum.
Expand All @@ -450,7 +452,8 @@ def segment_sum(data: Array,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
bucket_size: Optional[int] = None) -> Array:
bucket_size: Optional[int] = None,
mode: Optional[lax.GatherScatterMode] = None) -> Array:
"""Computes the sum within segments of an array.
Similar to TensorFlow's `segment_sum
Expand All @@ -460,8 +463,7 @@ def segment_sum(data: Array,
data: an array with the values to be summed.
segment_ids: an array with integer dtype that indicates the segments of
`data` (along its leading axis) to be summed. Values can be repeated and
need not be sorted. Values outside of the range [0, num_segments) are
dropped and do not contribute to the sum.
need not be sorted.
num_segments: optional, an int with nonnegative value indicating the number
of segments. The default is set to be the minimum number of segments that
would support all indices in ``segment_ids``, calculated as
Expand All @@ -473,6 +475,9 @@ def segment_sum(data: Array,
bucket_size: size of bucket to group indices into. ``segment_sum`` is
performed on each bucket separately to improve numerical stability of
addition. Default ``None`` means no bucketing.
mode: a :class:`lax.GatherScatterMode` value describing how out-of-bounds
indices should be handled. By default, values outside of the range
[0, num_segments) are dropped and do not contribute to the sum.
Returns:
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
Expand All @@ -492,16 +497,18 @@ def segment_sum(data: Array,
>>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
DeviceArray([1, 5, 4], dtype=int32)
"""
return _segment_update("segment_sum", data, segment_ids, lax.scatter_add, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.sum)
return _segment_update(
"segment_sum", data, segment_ids, lax.scatter_add, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.sum, mode=mode)


def segment_prod(data: Array,
segment_ids: Array,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
bucket_size: Optional[int] = None) -> Array:
bucket_size: Optional[int] = None,
mode: Optional[lax.GatherScatterMode] = None) -> Array:
"""Computes the product within segments of an array.
Similar to TensorFlow's `segment_prod
Expand All @@ -524,6 +531,9 @@ def segment_prod(data: Array,
bucket_size: size of bucket to group indices into. ``segment_prod`` is
performed on each bucket separately to improve numerical stability of
multiplication. Default ``None`` means no bucketing.
mode: a :class:`lax.GatherScatterMode` value describing how out-of-bounds
indices should be handled. By default, values outside of the range
[0, num_segments) are dropped and do not contribute to the product.
Returns:
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
Expand All @@ -543,16 +553,18 @@ def segment_prod(data: Array,
>>> jit(segment_prod, static_argnums=2)(data, segment_ids, 3)
DeviceArray([ 0, 6, 20], dtype=int32)
"""
return _segment_update("segment_prod", data, segment_ids, lax.scatter_mul, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.prod)
return _segment_update(
"segment_prod", data, segment_ids, lax.scatter_mul, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.prod, mode=mode)


def segment_max(data: Array,
segment_ids: Array,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
bucket_size: Optional[int] = None) -> Array:
bucket_size: Optional[int] = None,
mode: Optional[lax.GatherScatterMode] = None) -> Array:
"""Computes the maximum within segments of an array.
Similar to TensorFlow's `segment_max
Expand All @@ -574,6 +586,9 @@ def segment_max(data: Array,
unique_indices: whether `segment_ids` is known to be free of duplicates.
bucket_size: size of bucket to group indices into. ``segment_max`` is
performed on each bucket separately. Default ``None`` means no bucketing.
mode: a :class:`lax.GatherScatterMode` value describing how out-of-bounds
indices should be handled. By default, values outside of the range
[0, num_segments) are dropped and do not contribute to the maximum.
Returns:
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
Expand All @@ -593,16 +608,18 @@ def segment_max(data: Array,
>>> jit(segment_max, static_argnums=2)(data, segment_ids, 3)
DeviceArray([1, 3, 5], dtype=int32)
"""
return _segment_update("segment_max", data, segment_ids, lax.scatter_max, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.max)
return _segment_update(
"segment_max", data, segment_ids, lax.scatter_max, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.max, mode=mode)


def segment_min(data: Array,
segment_ids: Array,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
bucket_size: Optional[int] = None) -> Array:
bucket_size: Optional[int] = None,
mode: Optional[lax.GatherScatterMode] = None) -> Array:
"""Computes the minimum within segments of an array.
Similar to TensorFlow's `segment_min
Expand All @@ -624,6 +641,9 @@ def segment_min(data: Array,
unique_indices: whether `segment_ids` is known to be free of duplicates.
bucket_size: size of bucket to group indices into. ``segment_min`` is
performed on each bucket separately. Default ``None`` means no bucketing.
mode: a :class:`lax.GatherScatterMode` value describing how out-of-bounds
indices should be handled. By default, values outside of the range
[0, num_segments) are dropped and do not contribute to the minimum.
Returns:
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
Expand All @@ -643,5 +663,6 @@ def segment_min(data: Array,
>>> jit(segment_min, static_argnums=2)(data, segment_ids, 3)
DeviceArray([0, 2, 4], dtype=int32)
"""
return _segment_update("segment_min", data, segment_ids, lax.scatter_min, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.min)
return _segment_update(
"segment_min", data, segment_ids, lax.scatter_min, num_segments,
indices_are_sorted, unique_indices, bucket_size, jnp.min, mode=mode)
11 changes: 11 additions & 0 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,17 @@ def testSegmentSum(self):
expected = jnp.array([0, 0, 0, 13, 2, 7])
self.assertAllClose(ans, expected, check_dtypes=False)

def testSegmentSumOutOfBounds(self):
  """Checks segment_sum's value and gradient when every id is out of bounds."""
  num_segments = 2
  data = np.array([0, 0], dtype=np.float32)
  # Both ids fall outside the range [0, num_segments).
  segment_ids = np.array([2, 3])

  def total(values, ids):
    per_segment = jax.ops.segment_sum(values, ids, num_segments)
    return per_segment.sum()

  val, grad = jax.value_and_grad(total)(data, segment_ids)
  self.assertAllClose(val, np.array(0., np.float32))
  self.assertAllClose(grad, np.array([0., 0.], np.float32))


@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list({
Expand Down

0 comments on commit 2bf23d2

Please sign in to comment.