Skip to content

Commit

Permalink
Make complex_arr.astype(bool) follow NumPy's semantics
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 9, 2024
1 parent f1ae623 commit 3cff55a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ Remember to align the itemized text with the first line of an item within a list
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
the old behavior by transforming the arguments via
`jax.tree.map(np.asarray, args)` before passing them to the callback.
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.

* Deprecations & Removals
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
Expand Down
8 changes: 7 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2263,11 +2263,17 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
implementation dependent.
""")
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
util.check_arraylike("astype", x)
x_arr = asarray(x)
del copy # unused in JAX
if dtype is None:
dtype = dtypes.canonicalize_dtype(float_)
dtypes.check_user_dtype_supported(dtype, "astype")
return lax.convert_element_type(x, dtype)
if dtype == np.dtype('bool') and issubdtype(x_arr.dtype, complexfloating):
# Complex convert_element_type has the wrong semantics for boolean conversion
return (x_arr != _lax_const(x_arr, 0))
else:
return lax.convert_element_type(x_arr, dtype)


@util.implements(np.asarray, lax_description=_ARRAY_DOC)
Expand Down
18 changes: 18 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3822,6 +3822,24 @@ def testAstype(self, from_dtype, to_dtype, use_method):
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)

@jtu.sample_product(
from_dtype=['int32', 'float32', 'complex64'],
use_method=[True, False],
)
def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
rng = self.rng()
args_maker = lambda: [rng.randn(3, 4).astype(from_dtype)]
if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0
np_op = lambda x: np.astype(x, to_dtype)
else:
np_op = lambda x: np.asarray(x).astype(to_dtype)
if use_method:
jnp_op = lambda x: jnp.asarray(x).astype(to_dtype)
else:
jnp_op = lambda x: jnp.astype(x, to_dtype)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)

def testAstypeInt4(self):
# Test converting from int4 to int8
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)
Expand Down

0 comments on commit 3cff55a

Please sign in to comment.