From 30cd3b88fd6aaadb46fd8af7fd5ea66639ad9e3d Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 22 Apr 2024 16:25:37 +0000 Subject: [PATCH] Add support for copy kwarg in astype to match Array API --- CHANGELOG.md | 6 +++ jax/_src/numpy/array_methods.py | 6 ++- jax/_src/numpy/lax_numpy.py | 37 ++++++++++++++++--- .../array_api/_data_type_functions.py | 20 +++++++++- jax/numpy/__init__.pyi | 2 +- tests/lax_numpy_reducers_test.py | 9 ++++- tests/lax_numpy_test.py | 20 ++++++++++ 7 files changed, 87 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d67ea5fb46da..61a53e2a684f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,12 @@ Remember to align the itemized text with the first line of an item within a list * Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and related functions now raise an error, following a similar change in NumPy. +* Bug fixes + * {func}`jax.numpy.astype` will now always return a copy when `copy=True`. + Previously, no copy would be made when the output array would have the same + dtype as the input array. This may result in some increased memory usage. + The default value is set to `copy=False` to preserve backwards compatibility. 
+ ## jaxlib 0.4.27 ## jax 0.4.26 (April 3, 2024) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 98eea8887198..fd4661bdc4bd 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -31,11 +31,13 @@ import numpy as np import jax from jax import lax +from jax.sharding import Sharding from jax._src import core from jax._src import dtypes from jax._src.api_util import _ensure_index_tuple from jax._src.array import ArrayImpl from jax._src.lax import lax as lax_internal +from jax._src.lib import xla_client as xc from jax._src.numpy import lax_numpy from jax._src.numpy import reductions from jax._src.numpy import ufuncs @@ -55,7 +57,7 @@ # functions, which can themselves handle instances from any of these classes. -def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: +def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: """Copy the array and cast to a specified dtype. This is implemented via :func:`jax.lax.convert_element_type`, which may @@ -63,7 +65,7 @@ def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent. """ - return lax_numpy.astype(arr, dtype) + return lax_numpy.astype(arr, dtype, copy=copy, device=device) def _nbytes(arr: ArrayLike) -> int: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3767633deefc..5414a5e602e1 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2272,17 +2272,42 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: In particular, the details of float-to-int and int-to-float casts are implementation dependent. 
""") -def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array: +def astype(x: ArrayLike, dtype: DTypeLike | None, + /, *, copy: bool = False, + device: xc.Device | Sharding | None = None) -> Array: util.check_arraylike("astype", x) x_arr = asarray(x) - del copy # unused in JAX + if dtype is None: dtype = dtypes.canonicalize_dtype(float_) dtypes.check_user_dtype_supported(dtype, "astype") - # convert_element_type(complex, bool) has the wrong semantics. - if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating): - return (x_arr != _lax_const(x_arr, 0)) - return lax.convert_element_type(x_arr, dtype) + if issubdtype(x_arr.dtype, complexfloating): + if dtypes.isdtype(dtype, ("integral", "real floating")): + warnings.warn( + "Casting from complex to real dtypes will soon raise a ValueError. " + "Please first use jnp.real or jnp.imag to take the real/imaginary " + "component of your input.", + DeprecationWarning, stacklevel=2 + ) + elif np.dtype(dtype) == bool: + # convert_element_type(complex, bool) has the wrong semantics. + x_arr = (x_arr != _lax_const(x_arr, 0)) + + # We offer a more specific warning than the usual ComplexWarning so we prefer + # to issue our warning. 
+ with warnings.catch_warnings(): + warnings.simplefilter("ignore", ComplexWarning) + return _place_array( + lax.convert_element_type(x_arr, dtype), + device=device, copy=copy, + ) + +def _place_array(x, device=None, copy=None): + # TODO(micky774): Implement in future PRs as we formalize device placement + # semantics + if copy: + return _array_copy(x) + return x @util.implements(np.asarray, lax_description=_ARRAY_DOC) diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py index 4f72fcba29d0..770d264c1c07 100644 --- a/jax/experimental/array_api/_data_type_functions.py +++ b/jax/experimental/array_api/_data_type_functions.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import builtins import functools from typing import NamedTuple @@ -19,6 +21,9 @@ import jax.numpy as jnp +from jax._src.lib import xla_client as xc +from jax._src.sharding import Sharding +from jax._src import dtypes as _dtypes from jax.experimental.array_api._dtypes import ( bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 @@ -124,8 +129,19 @@ def _promote_types(t1, t2): raise ValueError("No promotion path for {t1} & {t2}") -def astype(x, dtype, /, *, copy=True): - return jnp.array(x, dtype=dtype, copy=copy) +def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None): + src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x) + if ( + src_dtype is not None + and _dtypes.isdtype(src_dtype, "complex floating") + and _dtypes.isdtype(dtype, ("integral", "real floating")) + ): + raise ValueError( + "Casting from complex to non-complex dtypes is not permitted. Please " + "first use jnp.real or jnp.imag to take the real/imaginary component of " + "your input." 
+ ) + return jnp.astype(x, dtype, copy=copy, device=device) def can_cast(from_, to, /): diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 2740638041cd..fea18f6eb522 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -115,7 +115,7 @@ def asarray( ) -> Array: ... def asin(x: ArrayLike, /) -> Array: ... def asinh(x: ArrayLike, /) -> Array: ... -def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ...) -> Array: ... +def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ..., device: _Device | _Sharding | None = ...) -> Array: ... def atan(x: ArrayLike, /) -> Array: ... def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ... def atanh(x: ArrayLike, /) -> Array: ... diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 73100352c544..6c7ea59ef2d0 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -776,8 +776,13 @@ def test_f16_mean(self, dtype): for axis in list( range(-len(shape), len(shape)) ) + ([None] if len(shape) == 1 else [])], - dtype=all_dtypes + [None], - out_dtype=all_dtypes, + [dict(dtype=dtype, out_dtype=out_dtype) + for dtype in (all_dtypes+[None]) + for out_dtype in ( + complex_dtypes if np.issubdtype(dtype, np.complexfloating) + else all_dtypes + ) + ], include_initial=[False, True], ) @jtu.ignore_warning(category=NumpyComplexWarning) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ddc599792e63..cf7ee39e434b 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3870,6 +3870,26 @@ def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product( + change_dtype=[True, False], + copy=[True, False], + ) + def testAstypeCopy(self, change_dtype, copy): + dtype = 'float32' if change_dtype else 'int32' + expect_copy = change_dtype or copy + x = 
jnp.arange(5, dtype='int32') + y = x.astype(dtype, copy=copy) + + self.assertEqual(y.dtype, dtype) + y.delete() + self.assertNotEqual(x.is_deleted(), expect_copy) + + def testAstypeComplexDowncast(self): + x = jnp.array(2.0+1.5j, dtype='complex64') + msg = "Casting from complex to non-complex dtypes will soon raise " + with self.assertWarns(DeprecationWarning, msg=msg): + x.astype('float32') + def testAstypeInt4(self): # Test converting from int4 to int8 x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)