Implement jax.ops.index_mul. #2696

Merged · 4 commits · Apr 13, 2020
1 change: 1 addition & 0 deletions docs/jax.ops.rst
@@ -20,6 +20,7 @@ pure alternatives, namely :func:`jax.ops.index_update` and its relatives.
index
index_update
index_add
index_mul
index_min
index_max

77 changes: 77 additions & 0 deletions jax/lax/lax.py
@@ -858,6 +858,33 @@ def scatter_add(operand: Array, scatter_indices: Array, updates: Array,
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers)

def scatter_mul(operand: Array, scatter_indices: Array, updates: Array,
dimension_numbers: ScatterDimensionNumbers) -> Array:
"""Scatter-multiply operator.

Wraps `XLA's Scatter operator
<https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
multiplication is used to combine updates and values from `operand`.

The semantics of scatter are complicated and its API is subject to change.

Args:
operand: an array to which the scatter should be applied
scatter_indices: an array that gives the indices in `operand` to which each
update in `updates` should be applied.
updates: the updates that should be scattered onto `operand`.
dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
how dimensions of `operand`, `scatter_indices`, `updates` and the output
relate.

Returns:
An array containing the product of `operand` and the scattered updates.
"""
jaxpr, consts = _reduction_jaxpr(mul, _abstractify(_const(operand, 1)))
return scatter_mul_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers)
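
# Illustrative sketch (not part of this diff): a self-contained, direct call
# to scatter_mul. The shapes and dimension_numbers below are chosen purely
# for illustration -- each scalar update multiplies one element of operand.
import jax.numpy as jnp
from jax import lax

operand = jnp.array([1., 2., 3., 4., 5.])
indices = jnp.array([[1], [3]])    # one row of start indices per update
updates = jnp.array([10., 100.])   # one scalar update per row
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),         # each update is a scalar window
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,))
print(lax.scatter_mul(operand, indices, updates, dnums))
# [  1.  20.   3. 400.   5.]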

def scatter_min(operand: Array, scatter_indices: Array, updates: Array,
dimension_numbers: ScatterDimensionNumbers) -> Array:
"""Scatter-min operator.
@@ -3458,6 +3485,39 @@ def _scatter_add_transpose_rule(t, operand, scatter_indices, updates, *,
slice_sizes=slice_sizes)
return [operand_t, None, update_t]

def _scatter_mul_transpose_rule(t, operand, scatter_indices, updates, *,
update_jaxpr, update_consts, dimension_numbers):
assert not ad.is_undefined_primal(scatter_indices)
if ad.is_undefined_primal(updates):
updates_shape = updates.aval.shape
else:
updates_shape = updates.shape
if t is ad_util.zero:
return [ad_util.zero, None, ad_util.zero]

operand_t = update_t = None
if ad.is_undefined_primal(operand):
operand_t = scatter_mul(t, scatter_indices, updates,
dimension_numbers=dimension_numbers)

if ad.is_undefined_primal(updates):
gather_dnums = GatherDimensionNumbers(
offset_dims=dimension_numbers.update_window_dims,
collapsed_slice_dims=dimension_numbers.inserted_window_dims,
start_index_map=dimension_numbers.scatter_dims_to_operand_dims)
slice_sizes = []
pos = 0
for i in range(len(t.shape)):
if i in dimension_numbers.inserted_window_dims:
slice_sizes.append(1)
else:
slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]])
pos += 1
update_t = gather(mul(t, operand), scatter_indices,
dimension_numbers=gather_dnums, slice_sizes=slice_sizes)
return [operand_t, None, update_t]
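
# Illustrative check (not part of this diff): for fixed indices and updates,
# scatter_mul is linear in `operand`, so the operand cotangent is the
# incoming cotangent scatter-multiplied by the same updates, while the
# updates cotangent gathers t * operand at the scattered locations.
import jax
import jax.numpy as jnp
from jax import ops

f = lambda x, y: ops.index_mul(x, ops.index[1:3], y)
x, y = jnp.arange(1., 5.), jnp.array([2., 3.])
_, vjp = jax.vjp(f, x, y)
print(vjp(jnp.ones(4)))  # ([1., 2., 3., 1.], [2., 3.])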


def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
update_jaxpr, update_consts, dimension_numbers):
operand, scatter_indices, updates = batched_args
@@ -3512,6 +3572,23 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
batching.primitive_batchers[scatter_add_p] = (
partial(_scatter_batching_rule, scatter_add))


scatter_mul_p = standard_primitive(
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul',
_scatter_translation_rule)

def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers, **kw):
return mul(x, scatter_add(zeros_like_array(x), i, g,
dimension_numbers=dimension_numbers))

ad.defjvp(scatter_mul_p,
lambda g, x, i, y, **kw: scatter_mul_p.bind(g, i, y, **kw),
None,
_scatter_mul_jvp_rhs)
ad.primitive_transposes[scatter_mul_p] = _scatter_mul_transpose_rule
batching.primitive_batchers[scatter_mul_p] = (
partial(_scatter_batching_rule, scatter_mul))
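
# Illustrative sketch (not part of this diff) of the two JVP rules above:
# the tangent in `operand` is scattered with the same multiplier, and the
# tangent in `updates` scales the operand at the scattered locations.
import jax
import jax.numpy as jnp
from jax import ops

f = lambda x, y: ops.index_mul(x, ops.index[1:2], y)
x, y = jnp.array([1., 2., 3.]), jnp.array([10.])
_, tangent = jax.jvp(f, (x, y), (jnp.ones(3), jnp.ones(1)))
print(tangent)  # [1., 12., 1.]: dx scaled by y, plus x[1] * dy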

# TODO(jlebar): Add derivatives.
scatter_min_p = standard_primitive(
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
4 changes: 3 additions & 1 deletion jax/ops/__init__.py
@@ -13,4 +13,6 @@
# limitations under the License.


from .scatter import index, index_add, index_update, index_min, index_max, segment_sum
from .scatter import (
index, index_add, index_mul, index_update, index_min, index_max, segment_sum
)
47 changes: 44 additions & 3 deletions jax/ops/scatter.py
@@ -41,7 +41,6 @@ def _scatter_update(x, idx, y, scatter_op):

x = np.asarray(x)
y = np.asarray(y)

# XLA gathers and scatters are very similar in structure; the scatter logic
# is more or less a transpose of the gather equivalent.
treedef, static_idx, dynamic_idx = np._split_index_for_jit(idx)
@@ -52,7 +51,8 @@ def _scatter_update(x, idx, y, scatter_op):
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
# @partial(jit, static_argnums=(2, 3, 4))
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx):
y = lax.convert_element_type(y, lax.dtype(x))
dtype = lax.dtype(x)
x, y = np._promote_dtypes(x, y)

idx = np._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = np._index_to_gather(np.shape(x), idx)
@@ -71,7 +71,8 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx):
inserted_window_dims=indexer.dnums.collapsed_slice_dims,
scatter_dims_to_operand_dims=indexer.dnums.start_index_map
)
return scatter_op(x, indexer.gather_indices, y, dnums)
out = scatter_op(x, indexer.gather_indices, y, dnums)
return lax.convert_element_type(out, dtype)
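
# Illustrative sketch (not part of this diff) of the dtype handling above:
# x and y are promoted to a common dtype for the scatter, and the result is
# then cast back to x's original dtype.
import jax.numpy as jnp
from jax import ops

x = jnp.arange(4)                            # int32
out = ops.index_mul(x, ops.index[1:3], 2.5)  # promoted to float internally
print(out.dtype)                             # int32 -- cast back to x's dtype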


class _Indexable(object):
@@ -130,6 +131,46 @@ def index_add(x, idx, y):
"""
return _scatter_update(x, idx, y, lax.scatter_add)


def index_mul(x, idx, y):
"""Pure equivalent of :code:`x[idx] *= y`.

Returns the value of `x` that would result from the
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::

    x[idx] *= y

Note that the `index_mul` operator is pure: `x` itself is not modified;
instead, the new value that `x` would have taken is returned.

Unlike the NumPy code :code:`x[idx] *= y`, if multiple indices refer to the
same location the updates will be multiplied. (NumPy would only apply the last
update, rather than multiplying the updates.) The order in which conflicting
updates are applied is implementation-defined and may be nondeterministic
(e.g., due to concurrency on some hardware platforms).

Args:
x: an array with the values to be updated.
idx: a NumPy-style index, consisting of `None`, integers, `slice` objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
:data:`jax.ops.index` object.
y: the array of updates. `y` must be broadcastable to the shape of the
array that would be returned by `x[idx]`.

Returns:
An array.

>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_mul(x, jax.ops.index[2:4, 3:], 6.)
array([[1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 6., 6., 6.],
       [1., 1., 1., 6., 6., 6.],
       [1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
return _scatter_update(x, idx, y, lax.scatter_mul)
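
# Illustrative sketch (not part of this diff): duplicate indices multiply
# into the same slot (unlike NumPy, which keeps only the last update), and
# gradients flow through both arguments.
import jax
import jax.numpy as jnp
from jax import ops

x = jnp.ones(3)
idx = jnp.array([1, 1])
print(ops.index_mul(x, idx, jnp.array([2., 3.])))  # [1., 6., 1.]

g = jax.grad(lambda x, y: ops.index_mul(x, ops.index[0], y).sum(), (0, 1))
print(g(jnp.array([2., 3.]), 5.))                  # ([5., 1.], 2.)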


def index_min(x, idx, y):
"""Pure equivalent of :code:`x[idx] = minimum(x[idx], y)`.

12 changes: 7 additions & 5 deletions tests/lax_numpy_indexing_test.py
@@ -816,15 +816,17 @@ def _update_shape(shape, indexer):
class UpdateOps(enum.Enum):
UPDATE = 0
ADD = 1
MIN = 2
MAX = 3
MUL = 2
MIN = 3
MAX = 4

@suppress_deprecated_indexing_warnings()
def onp_fn(op, indexer, x, y):
x = x.copy()
x[indexer] = {
UpdateOps.UPDATE: lambda: y,
UpdateOps.ADD: lambda: x[indexer] + y,
UpdateOps.MUL: lambda: x[indexer] * y,
UpdateOps.MIN: lambda: onp.minimum(x[indexer], y),
UpdateOps.MAX: lambda: onp.maximum(x[indexer], y),
}[op]()
@@ -834,6 +836,7 @@ def jax_fn(op, indexer, x, y):
return {
UpdateOps.UPDATE: ops.index_update,
UpdateOps.ADD: ops.index_add,
UpdateOps.MUL: ops.index_mul,
UpdateOps.MIN: ops.index_min,
UpdateOps.MAX: ops.index_max,
}[op](x, indexer, y)
@@ -919,7 +922,7 @@ def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
"op": op
} for name, index_specs in STATIC_INDEXING_TESTS
for shape, indexer in index_specs
for op in UpdateOps
for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
for dtype in float_dtypes
for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes)
Expand All @@ -928,8 +931,7 @@ def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
rng_factory, indexer, op):
rng = rng_factory()
jax_op = ops.index_update if op == UpdateOps.UPDATE else ops.index_add
jax_fn = lambda x, y: jax_op(x, indexer, y)
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
x = rng(shape, dtype)
y = rng(update_shape, update_dtype)
check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.)