Skip to content

Commit

Permalink
Implement jax.ops.index_mul. (jax-ml#2696)
Browse files Browse the repository at this point in the history
* Implement jax.ops.index_mul.

* Add index_mul to documentation.

* Fix RHS JVP rule for scatter_mul, fix test bug that meant it was not tested.

* Fix typo in docstring.
  • Loading branch information
hawkinsp authored and NeilGirdhar committed Apr 13, 2020
1 parent d8cd662 commit dbb771a
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/jax.ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pure alternatives, namely :func:`jax.ops.index_update` and its relatives.
index
index_update
index_add
index_mul
index_min
index_max

Expand Down
77 changes: 77 additions & 0 deletions jax/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,33 @@ def scatter_add(operand: Array, scatter_indices: Array, updates: Array,
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers)

def scatter_mul(operand: Array, scatter_indices: Array, updates: Array,
                dimension_numbers: ScatterDimensionNumbers) -> Array:
  """Scatter-multiply operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
  multiplication is used to combine updates and values from `operand`.

  The semantics of scatter are complicated and its API is subject to change.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `start_indices`, `updates` and the output
      relate.

  Returns:
    An array containing the result of multiplying `operand` by the
    scattered updates.
  """
  # Build the update computation for the Scatter op: an element-wise `mul`
  # reduction jaxpr, abstracted at operand's dtype (identity constant 1).
  jaxpr, consts = _reduction_jaxpr(mul, _abstractify(_const(operand, 1)))
  return scatter_mul_p.bind(
      operand, scatter_indices, updates, update_jaxpr=jaxpr,
      update_consts=consts, dimension_numbers=dimension_numbers)

def scatter_min(operand: Array, scatter_indices: Array, updates: Array,
dimension_numbers: ScatterDimensionNumbers) -> Array:
"""Scatter-min operator.
Expand Down Expand Up @@ -3458,6 +3485,39 @@ def _scatter_add_transpose_rule(t, operand, scatter_indices, updates, *,
slice_sizes=slice_sizes)
return [operand_t, None, update_t]

def _scatter_mul_transpose_rule(t, operand, scatter_indices, updates, *,
                                update_jaxpr, update_consts, dimension_numbers):
  """Transpose rule for scatter-mul.

  Scatter-mul is linear in `operand` (for fixed `updates`) and in `updates`
  (for fixed `operand`), so each cotangent is computed assuming the other
  primal is known.
  """
  assert not ad.is_undefined_primal(scatter_indices)
  updates_shape = (updates.aval.shape if ad.is_undefined_primal(updates)
                   else updates.shape)
  # Symbolic-zero cotangent: propagate zeros without building any ops.
  if t is ad_util.zero:
    return [ad_util.zero, None, ad_util.zero]

  operand_t = None
  update_t = None
  if ad.is_undefined_primal(operand):
    # d/d(operand): scatter-multiply the cotangent by the known updates.
    operand_t = scatter_mul(t, scatter_indices, updates,
                            dimension_numbers=dimension_numbers)

  if ad.is_undefined_primal(updates):
    # d/d(updates): gather t * operand back at the scattered locations.
    gather_dnums = GatherDimensionNumbers(
        offset_dims=dimension_numbers.update_window_dims,
        collapsed_slice_dims=dimension_numbers.inserted_window_dims,
        start_index_map=dimension_numbers.scatter_dims_to_operand_dims)
    # Slice size is 1 along inserted window dims; otherwise it is the extent
    # of the matching update-window dimension.
    slice_sizes = []
    window_pos = 0
    for dim in range(len(t.shape)):
      if dim in dimension_numbers.inserted_window_dims:
        slice_sizes.append(1)
      else:
        window_dim = dimension_numbers.update_window_dims[window_pos]
        slice_sizes.append(updates_shape[window_dim])
        window_pos += 1
    update_t = gather(mul(t, operand), scatter_indices,
                      dimension_numbers=gather_dnums, slice_sizes=slice_sizes)
  return [operand_t, None, update_t]


def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
update_jaxpr, update_consts, dimension_numbers):
operand, scatter_indices, updates = batched_args
Expand Down Expand Up @@ -3512,6 +3572,23 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
batching.primitive_batchers[scatter_add_p] = (
partial(_scatter_batching_rule, scatter_add))


# Primitive registration for scatter-mul: shape/dtype rules and XLA lowering
# are shared with the other scatter variants.
scatter_mul_p = standard_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul',
    _scatter_translation_rule)

def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers, **kw):
  # JVP w.r.t. `updates` (y): the tangent contribution is x multiplied by the
  # update tangents scattered (via add) onto a zero array of x's shape.
  # NOTE(review): this appears to assume each output location receives at most
  # one update, so the add-scatter places each tangent once — confirm.
  return mul(x, scatter_add(zeros_like_array(x), i, g,
                            dimension_numbers=dimension_numbers))

ad.defjvp(scatter_mul_p,
          # JVP w.r.t. `operand` (x): scatter-mul is linear in x, so the
          # tangent is scatter-mul applied to g with the same indices/updates.
          lambda g, x, i, y, **kw: scatter_mul_p.bind(g, i, y, **kw),
          None,  # scatter indices are integer-valued: no tangent.
          _scatter_mul_jvp_rhs)
ad.primitive_transposes[scatter_mul_p] = _scatter_mul_transpose_rule
batching.primitive_batchers[scatter_mul_p] = (
    partial(_scatter_batching_rule, scatter_mul))

# TODO(jlebar): Add derivatives.
scatter_min_p = standard_primitive(
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
Expand Down
4 changes: 3 additions & 1 deletion jax/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@
# limitations under the License.


from .scatter import index, index_add, index_update, index_min, index_max, segment_sum
from .scatter import (
index, index_add, index_mul, index_update, index_min, index_max, segment_sum
)
47 changes: 44 additions & 3 deletions jax/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def _scatter_update(x, idx, y, scatter_op):

x = np.asarray(x)
y = np.asarray(y)

# XLA gathers and scatters are very similar in structure; the scatter logic
# is more or less a transpose of the gather equivalent.
treedef, static_idx, dynamic_idx = np._split_index_for_jit(idx)
Expand All @@ -52,7 +51,8 @@ def _scatter_update(x, idx, y, scatter_op):
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
# @partial(jit, static_argnums=(2, 3, 4))
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx):
y = lax.convert_element_type(y, lax.dtype(x))
dtype = lax.dtype(x)
x, y = np._promote_dtypes(x, y)

idx = np._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = np._index_to_gather(np.shape(x), idx)
Expand All @@ -71,7 +71,8 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx):
inserted_window_dims=indexer.dnums.collapsed_slice_dims,
scatter_dims_to_operand_dims=indexer.dnums.start_index_map
)
return scatter_op(x, indexer.gather_indices, y, dnums)
out = scatter_op(x, indexer.gather_indices, y, dnums)
return lax.convert_element_type(out, dtype)


class _Indexable(object):
Expand Down Expand Up @@ -130,6 +131,46 @@ def index_add(x, idx, y):
"""
return _scatter_update(x, idx, y, lax.scatter_add)


def index_mul(x, idx, y):
  """Pure equivalent of :code:`x[idx] *= y`.

  Returns the value of `x` that would result from the
  NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::

    x[idx] *= y

  Note the `index_mul` operator is pure; `x` itself is
  not modified, instead the new value that `x` would have taken is returned.

  Unlike the NumPy code :code:`x[idx] *= y`, if multiple indices refer to the
  same location the updates will be multiplied. (NumPy would only apply the last
  update, rather than multiplying the updates.) The order in which conflicting
  updates are applied is implementation-defined and may be nondeterministic
  (e.g., due to concurrency on some hardware platforms).

  Args:
    x: an array with the values to be updated.
    idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
      ellipses, ndarrays with integer dtypes, or a tuple of the above. A
      convenient syntactic sugar for forming indices is via the
      :data:`jax.ops.index` object.
    y: the array of updates. `y` must be broadcastable to the shape of the
      array that would be returned by `x[idx]`.

  Returns:
    An array.

  >>> x = jax.numpy.ones((5, 6))
  >>> jax.ops.index_mul(x, jax.ops.index[2:4, 3:], 6.)
  array([[1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 6., 6., 6.],
         [1., 1., 1., 6., 6., 6.],
         [1., 1., 1., 1., 1., 1.]], dtype=float32)
  """
  return _scatter_update(x, idx, y, lax.scatter_mul)


def index_min(x, idx, y):
"""Pure equivalent of :code:`x[idx] = minimum(x[idx], y)`.
Expand Down
12 changes: 7 additions & 5 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,15 +816,17 @@ def _update_shape(shape, indexer):
class UpdateOps(enum.Enum):
UPDATE = 0
ADD = 1
MIN = 2
MAX = 3
MUL = 2
MIN = 3
MAX = 4

@suppress_deprecated_indexing_warnings()
def onp_fn(op, indexer, x, y):
x = x.copy()
x[indexer] = {
UpdateOps.UPDATE: lambda: y,
UpdateOps.ADD: lambda: x[indexer] + y,
UpdateOps.MUL: lambda: x[indexer] * y,
UpdateOps.MIN: lambda: onp.minimum(x[indexer], y),
UpdateOps.MAX: lambda: onp.maximum(x[indexer], y),
}[op]()
Expand All @@ -834,6 +836,7 @@ def jax_fn(op, indexer, x, y):
return {
UpdateOps.UPDATE: ops.index_update,
UpdateOps.ADD: ops.index_add,
UpdateOps.MUL: ops.index_mul,
UpdateOps.MIN: ops.index_min,
UpdateOps.MAX: ops.index_max,
}[op](x, indexer, y)
Expand Down Expand Up @@ -919,7 +922,7 @@ def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
"op": op
} for name, index_specs in STATIC_INDEXING_TESTS
for shape, indexer in index_specs
for op in UpdateOps
for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
for dtype in float_dtypes
for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes)
Expand All @@ -928,8 +931,7 @@ def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
rng_factory, indexer, op):
rng = rng_factory()
jax_op = ops.index_update if op == UpdateOps.UPDATE else ops.index_add
jax_fn = lambda x, y: jax_op(x, indexer, y)
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
x = rng(shape, dtype)
y = rng(update_shape, update_dtype)
check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.)
Expand Down

0 comments on commit dbb771a

Please sign in to comment.