From 88c721656ec5334f5c9553f88342703b77e07750 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 25 Jun 2021 08:43:04 +0200 Subject: [PATCH] Ensure zeros from AD are generated on device. Fixes: #7093 --- jax/_src/abstract_arrays.py | 12 ++---------- jax/_src/lax/lax.py | 7 +++++++ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 349e631157ae..8a2df49fe52a 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -38,7 +38,8 @@ def make_shaped_array(x): def zeros_like_array(x): dtype = dtypes.canonicalize_dtype(dtypes.result_type(x)) - return zeros_like_shaped_array(ShapedArray(np.shape(x), dtype)) + aval = ShapedArray(np.shape(x), dtype) + return ad_util.zeros_like_aval(aval) array_types = {np.ndarray, np.bool_, np.int8, np.int16, np.int32, np.int64, @@ -51,15 +52,6 @@ def zeros_like_array(x): core.pytype_aval_mappings[t] = ConcreteArray ad_util.jaxval_zeros_likers[t] = zeros_like_array - -def zeros_like_shaped_array(aval): - assert isinstance(aval, ShapedArray) - if aval.dtype == dtypes.float0: - return np.zeros(aval.shape, dtypes.float0) - return np.broadcast_to(np.array(0, aval.dtype), aval.shape) - -ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array - core.literalable_types.update(array_types) def _zeros_like_python_scalar(t, x): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 99f9020ea34f..fb3d43d037ff 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1595,6 +1595,13 @@ def _device_put_raw(x, weak_type=None): aval = raise_to_shaped(core.get_aval(x), weak_type=weak_type) return xla.array_result_handler(None, aval)(*xla.device_put(x)) +def zeros_like_shaped_array(aval): + assert isinstance(aval, ShapedArray) + # The .astype is useful for float0 + return broadcast(np.array(0).astype(aval.dtype), aval.shape) + +ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array + def iota(dtype: DType, size: int) -> Array: """Wraps XLA's `Iota `_