Skip to content

Commit

Permalink
Merge pull request #16933 from jakevdp:jnp-place
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 553499383
  • Loading branch information
jax authors committed Aug 3, 2023
2 parents 0228bf7 + bd5a457 commit a22c477
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 10 deletions.
44 changes: 34 additions & 10 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4949,18 +4949,42 @@ def _const(v):
return vectorize(lax.switch, excluded=(1,))(indices, funclist, x)


def _tile_to_size(arr: Array, size: int) -> Array:
assert arr.ndim == 1
if arr.size < size:
arr = tile(arr, int(np.ceil(size / arr.size)))
assert arr.size >= size
return arr[:size] if arr.size > size else arr


@util._wraps(np.place, lax_description="""
Numpy function :func:`numpy.place` is not available in JAX and will raise a
:class:`NotImplementedError`, because ``np.place`` modifies its arguments in-place,
and in JAX arrays are immutable. A JAX-compatible approach to array updates
can be found in :attr:`jax.numpy.ndarray.at`.
The semantics of :func:`numpy.place` is to modify arrays in-place, which JAX
cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.place` adds
the ``inplace`` parameter, which must be set to ``False`` by the user as a
reminder of this API difference.
""", extra_params="""
inplace : bool, default=True
If left to its default value of True, JAX will raise an error. This is because
the semantics of :func:`numpy.put` are to modify the array in-place, which is
not possible in JAX due to the immutability of JAX arrays.
""")
def place(*args, **kwargs):
raise NotImplementedError(
"jax.numpy.place is not implemented because JAX arrays cannot be modified in-place. "
"For functional approaches to updating array values, see jax.numpy.ndarray.at: "
"https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.")
def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *,
inplace: bool = True) -> Array:
util.check_arraylike("place", arr, mask, vals)
data, mask_arr, vals_arr = asarray(arr), asarray(mask), ravel(vals)
if inplace:
raise ValueError(
"jax.numpy.place cannot modify arrays in-place, because JAX arrays are immutable. "
"Pass inplace=False to instead return an updated array.")
if data.size != mask_arr.size:
raise ValueError("place: arr and mask must be the same size")
if not vals_arr.size:
raise ValueError("Cannot place values from an empty array")
if not data.size:
return data
indices = where(mask_arr.ravel(), size=mask_arr.size, fill_value=mask_arr.size)[0]
vals_arr = _tile_to_size(vals_arr, len(indices))
return data.ravel().at[indices].set(vals_arr, mode='drop').reshape(data.shape)


@util._wraps(np.put, lax_description="""
Expand All @@ -4980,7 +5004,7 @@ def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
arr, ind_arr, v_arr = asarray(a), ravel(ind), ravel(v)
if not arr.size or not ind_arr.size or not v_arr.size:
return arr
v_arr = tile(v_arr, int(np.ceil(len(ind_arr) / len(v_arr))))[:len(ind_arr)]
v_arr = _tile_to_size(v_arr, len(ind_arr))
if inplace:
raise ValueError(
"jax.numpy.put cannot modify arrays in-place, because JAX arrays are immutable. "
Expand Down
24 changes: 24 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5102,6 +5102,30 @@ def testFromString(self):
actual = jnp.fromstring(s, sep=',', dtype=int)
self.assertArraysEqual(expected, actual)

@jtu.sample_product(
a_shape=nonempty_nonscalar_array_shapes,
v_shape=nonempty_shapes,
dtype=jtu.dtypes.all,
)
def testPlace(self, a_shape, v_shape, dtype):
rng = jtu.rand_default(self.rng())
mask_rng = jtu.rand_bool(self.rng())

def args_maker():
a = rng(a_shape, dtype)
m = mask_rng(a_shape, bool)
v = rng(v_shape, dtype)
return a, m, v

def np_fun(a, m, v):
a_copy = a.copy()
np.place(a_copy, m, v)
return a_copy

jnp_fun = partial(jnp.place, inplace=False)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
a_shape=nonempty_nonscalar_array_shapes,
i_shape=all_shapes,
Expand Down

0 comments on commit a22c477

Please sign in to comment.