Skip to content

Commit

Permalink
Ensure zeros from AD are generated on device.
Browse files Browse the repository at this point in the history
  • Loading branch information
gnecula committed Jun 25, 2021
1 parent 9657521 commit 88c7216
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
12 changes: 2 additions & 10 deletions jax/_src/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def make_shaped_array(x):

def zeros_like_array(x):
dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
return zeros_like_shaped_array(ShapedArray(np.shape(x), dtype))
aval = ShapedArray(np.shape(x), dtype)
return ad_util.zeros_like_aval(aval)

array_types = {np.ndarray, np.bool_,
np.int8, np.int16, np.int32, np.int64,
Expand All @@ -51,15 +52,6 @@ def zeros_like_array(x):
core.pytype_aval_mappings[t] = ConcreteArray
ad_util.jaxval_zeros_likers[t] = zeros_like_array


def zeros_like_shaped_array(aval):
assert isinstance(aval, ShapedArray)
if aval.dtype == dtypes.float0:
return np.zeros(aval.shape, dtypes.float0)
return np.broadcast_to(np.array(0, aval.dtype), aval.shape)

ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array

core.literalable_types.update(array_types)

def _zeros_like_python_scalar(t, x):
Expand Down
7 changes: 7 additions & 0 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,13 @@ def _device_put_raw(x, weak_type=None):
aval = raise_to_shaped(core.get_aval(x), weak_type=weak_type)
return xla.array_result_handler(None, aval)(*xla.device_put(x))

def zeros_like_shaped_array(aval):
assert isinstance(aval, ShapedArray)
# The .astype is useful for float0
return broadcast(np.array(0).astype(aval.dtype), aval.shape)

ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array

def iota(dtype: DType, size: int) -> Array:
"""Wraps XLA's `Iota
<https://www.tensorflow.org/xla/operation_semantics#iota>`_
Expand Down

0 comments on commit 88c7216

Please sign in to comment.