From 2bf23d29f46c3c4d7b06eedd2c3bee12658f9e0b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 22 Nov 2021 12:06:11 -0800 Subject: [PATCH] Change the default out-of-bounds behavior for jax.ops.segment_... to FILL_OR_DROP. This matches the documented behavior. Fixes https://github.com/google/jax/issues/8634 PiperOrigin-RevId: 411617006 --- jax/_src/ops/scatter.py | 53 ++++++++++++++++++++++---------- tests/lax_numpy_indexing_test.py | 11 +++++++ 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 24e8ee7bb6f6..587d57d6abcc 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -412,8 +412,10 @@ def _segment_update(name: str, indices_are_sorted: bool = False, unique_indices: bool = False, bucket_size: Optional[int] = None, - reducer: Optional[Callable] = None) -> Array: + reducer: Optional[Callable] = None, + mode: Optional[lax.GatherScatterMode] = None) -> Array: jnp._check_arraylike(name, data, segment_ids) + mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode data = jnp.asarray(data) segment_ids = jnp.asarray(segment_ids) dtype = data.dtype @@ -430,7 +432,7 @@ def _segment_update(name: str, if num_buckets == 1: return _scatter_update( out, segment_ids, data, scatter_op, indices_are_sorted, - unique_indices, normalize_indices=False) + unique_indices, normalize_indices=False, mode=mode) # Bucketize indices and perform segment_update on each bucket to improve # numerical stability for operations like product and sum. @@ -450,7 +452,8 @@ def segment_sum(data: Array, num_segments: Optional[int] = None, indices_are_sorted: bool = False, unique_indices: bool = False, - bucket_size: Optional[int] = None) -> Array: + bucket_size: Optional[int] = None, + mode: Optional[lax.GatherScatterMode] = None) -> Array: """Computes the sum within segments of an array. 
Similar to TensorFlow's `segment_sum @@ -460,8 +463,7 @@ def segment_sum(data: Array, data: an array with the values to be summed. segment_ids: an array with integer dtype that indicates the segments of `data` (along its leading axis) to be summed. Values can be repeated and - need not be sorted. Values outside of the range [0, num_segments) are - dropped and do not contribute to the sum. + need not be sorted. num_segments: optional, an int with nonnegative value indicating the number of segments. The default is set to be the minimum number of segments that would support all indices in ``segment_ids``, calculated as @@ -473,6 +475,9 @@ def segment_sum(data: Array, bucket_size: size of bucket to group indices into. ``segment_sum`` is performed on each bucket separately to improve numerical stability of addition. Default ``None`` means no bucketing. + mode: a :class:`lax.GatherScatterMode` value describing how out-of-bounds + indices should be handled. By default, values outside of the range + [0, num_segments) are dropped and do not contribute to the sum. Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the @@ -492,8 +497,9 @@ def segment_sum(data: Array, >>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3) DeviceArray([1, 5, 4], dtype=int32) """ - return _segment_update("segment_sum", data, segment_ids, lax.scatter_add, num_segments, - indices_are_sorted, unique_indices, bucket_size, jnp.sum) + return _segment_update( + "segment_sum", data, segment_ids, lax.scatter_add, num_segments, + indices_are_sorted, unique_indices, bucket_size, jnp.sum, mode=mode) def segment_prod(data: Array, @@ -501,7 +507,8 @@ def segment_prod(data: Array, num_segments: Optional[int] = None, indices_are_sorted: bool = False, unique_indices: bool = False, - bucket_size: Optional[int] = None) -> Array: + bucket_size: Optional[int] = None, + mode: Optional[lax.GatherScatterMode] = None) -> Array: """Computes the product within segments of an array. 
Similar to TensorFlow's `segment_prod @@ -524,6 +531,9 @@ def segment_prod(data: Array, bucket_size: size of bucket to group indices into. ``segment_prod`` is performed on each bucket separately to improve numerical stability of addition. Default ``None`` means no bucketing. + mode: a :class:`lax.GatherScatterMode` value describing how out-of-bounds + indices should be handled. By default, values outside of the range + [0, num_segments) are dropped and do not contribute to the product. Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the @@ -543,8 +553,9 @@ def segment_prod(data: Array, >>> jit(segment_prod, static_argnums=2)(data, segment_ids, 3) DeviceArray([ 0, 6, 20], dtype=int32) """ - return _segment_update("segment_prod", data, segment_ids, lax.scatter_mul, num_segments, - indices_are_sorted, unique_indices, bucket_size, jnp.prod) + return _segment_update( + "segment_prod", data, segment_ids, lax.scatter_mul, num_segments, + indices_are_sorted, unique_indices, bucket_size, jnp.prod, mode=mode) def segment_max(data: Array, @@ -552,7 +563,8 @@ def segment_max(data: Array, num_segments: Optional[int] = None, indices_are_sorted: bool = False, unique_indices: bool = False, - bucket_size: Optional[int] = None) -> Array: + bucket_size: Optional[int] = None, + mode: Optional[lax.GatherScatterMode] = None) -> Array: """Computes the maximum within segments of an array. Similar to TensorFlow's `segment_max @@ -574,6 +586,9 @@ def segment_max(data: Array, unique_indices: whether `segment_ids` is known to be free of duplicates. bucket_size: size of bucket to group indices into. ``segment_max`` is performed on each bucket separately. Default ``None`` means no bucketing. + mode: a :class:`lax.GatherScatterMode` value describing how out-of-bounds + indices should be handled. By default, values outside of the range + [0, num_segments) are dropped and do not contribute to the maximum. 
Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the @@ -593,8 +608,9 @@ def segment_max(data: Array, >>> jit(segment_max, static_argnums=2)(data, segment_ids, 3) DeviceArray([1, 3, 5], dtype=int32) """ - return _segment_update("segment_max", data, segment_ids, lax.scatter_max, num_segments, - indices_are_sorted, unique_indices, bucket_size, jnp.max) + return _segment_update( + "segment_max", data, segment_ids, lax.scatter_max, num_segments, + indices_are_sorted, unique_indices, bucket_size, jnp.max, mode=mode) def segment_min(data: Array, @@ -602,7 +618,8 @@ def segment_min(data: Array, num_segments: Optional[int] = None, indices_are_sorted: bool = False, unique_indices: bool = False, - bucket_size: Optional[int] = None) -> Array: + bucket_size: Optional[int] = None, + mode: Optional[lax.GatherScatterMode] = None) -> Array: """Computes the minimum within segments of an array. Similar to TensorFlow's `segment_min @@ -624,6 +641,9 @@ def segment_min(data: Array, unique_indices: whether `segment_ids` is known to be free of duplicates. bucket_size: size of bucket to group indices into. ``segment_min`` is performed on each bucket separately. Default ``None`` means no bucketing. + mode: a :class:`lax.GatherScatterMode` value describing how out-of-bounds + indices should be handled. By default, values outside of the range + [0, num_segments) are dropped and do not contribute to the minimum. 
Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the @@ -643,5 +663,6 @@ def segment_min(data: Array, >>> jit(segment_min, static_argnums=2)(data, segment_ids, 3) DeviceArray([0, 2, 4], dtype=int32) """ - return _segment_update("segment_min", data, segment_ids, lax.scatter_min, num_segments, - indices_are_sorted, unique_indices, bucket_size, jnp.min) + return _segment_update( + "segment_min", data, segment_ids, lax.scatter_min, num_segments, + indices_are_sorted, unique_indices, bucket_size, jnp.min, mode=mode) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 0f042d441e34..9d34008d7eca 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -1155,6 +1155,17 @@ def testSegmentSum(self): expected = jnp.array([0, 0, 0, 13, 2, 7]) self.assertAllClose(ans, expected, check_dtypes=False) + def testSegmentSumOutOfBounds(self): + def fn(data, segment_ids): + return jax.ops.segment_sum(data, segment_ids, num_segments).sum() + + data = np.array([0, 0], dtype=np.float32) + num_segments = 2 + segment_ids = np.array([2, 3]) + val, grad = jax.value_and_grad(fn)(data, segment_ids) + self.assertAllClose(val, np.array(0., np.float32)) + self.assertAllClose(grad, np.array([0., 0.], np.float32)) + @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list({