diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 032eb8eb584e..72b2ed7f9cc0 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -53,6 +53,7 @@ bincount as bincount, blackman as blackman, block as block, + bool_ as bool, # Array API alias for bool_ bool_ as bool_, broadcast_arrays as broadcast_arrays, broadcast_shapes as broadcast_shapes, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index bfb0561240da..bdff69ca57ad 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -156,6 +156,7 @@ def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... def blackman(M: int) -> Array: ... def block(arrays: Union[ArrayLike, Sequence[ArrayLike], Sequence[Sequence[ArrayLike]]]) -> Array: ... +bool: Any bool_: Any def broadcast_arrays(*args: ArrayLike) -> list[Array]: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 9e4f1811c7e9..c4715808f32f 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -175,7 +175,7 @@ def f(): return f @parameterized.parameters( - [dtype for dtype in [jnp.bool_, jnp.uint8, jnp.uint16, jnp.uint32, + [dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64, jnp.complex64, jnp.complex128] @@ -191,6 +191,9 @@ def testDtypeWrappers(self, dtype): prims = [eqn.primitive for eqn in jaxpr.eqns] self.assertEqual(prims, [lax.convert_element_type_p]) # No copy generated. + def testBoolDtypeAlias(self): + self.assertIs(jnp.bool, jnp.bool_) + @jtu.sample_product( dtype=float_dtypes + [object], allow_pickle=[True, False],