Skip to content

Commit

Permalink
Merge pull request #19470 from jakevdp:device-arg
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 601205471
  • Loading branch information
jax authors committed Jan 24, 2024
2 parents 831c25f + d55cd7c commit 8b81555
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 15 deletions.
35 changes: 26 additions & 9 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from jax import jit
from jax import errors
from jax import lax
from jax.sharding import Sharding, SingleDeviceSharding
from jax.tree_util import tree_leaves, tree_flatten, tree_map

from jax._src import api_util
Expand All @@ -58,6 +59,7 @@
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
_sort_le_comparator, PrecisionLike)
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_client as xc
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.numpy import util
Expand Down Expand Up @@ -2257,16 +2259,28 @@ def empty_like(prototype: ArrayLike | DuckTypedArray,
return zeros_like(prototype, dtype=dtype, shape=shape)


def _maybe_device_put(arr: Array, device: xc.Device | Sharding | None) -> Array:
return arr if device is None else jax.device_put(arr, device)

def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | None:
if isinstance(device, xc.Device):
return SingleDeviceSharding(device)
else:
return device


@util._wraps(np.full)
def full(shape: Any, fill_value: ArrayLike,
dtype: DTypeLike | None = None) -> Array:
dtype: DTypeLike | None = None, *,
device: xc.Device | Sharding | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "full")
util.check_arraylike("full", fill_value)

if ndim(fill_value) == 0:
shape = canonicalize_shape(shape)
return lax.full(shape, fill_value, dtype)
return lax.full(shape, fill_value, dtype, sharding=_normalize_to_sharding(device))
else:
return broadcast_to(asarray(fill_value, dtype=dtype), shape)
return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device)


@util._wraps(np.full_like)
Expand All @@ -2289,30 +2303,33 @@ def full_like(a: ArrayLike | DuckTypedArray,


@util._wraps(np.zeros)
def zeros(shape: Any, dtype: DTypeLike | None = None) -> Array:
def zeros(shape: Any, dtype: DTypeLike | None = None, *,
device: xc.Device | Sharding | None = None) -> Array:
if isinstance(shape, types.GeneratorType):
raise TypeError("expected sequence object with len >= 0 or a single integer")
if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m)
dtypes.check_user_dtype_supported(dtype, "zeros")
shape = canonicalize_shape(shape)
return lax.full(shape, 0, _jnp_dtype(dtype))
return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))

@util._wraps(np.ones)
def ones(shape: Any, dtype: DTypeLike | None = None) -> Array:
def ones(shape: Any, dtype: DTypeLike | None = None, *,
device: xc.Device | Sharding | None = None) -> Array:
if isinstance(shape, types.GeneratorType):
raise TypeError("expected sequence object with len >= 0 or a single integer")
if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
shape = canonicalize_shape(shape)
dtypes.check_user_dtype_supported(dtype, "ones")
return lax.full(shape, 1, _jnp_dtype(dtype))
return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))

@util._wraps(np.empty, lax_description="""\
Because XLA cannot create uninitialized arrays, the JAX version will
return an array initialized with zeros.""")
def empty(shape: Any, dtype: DTypeLike | None = None) -> Array:
def empty(shape: Any, dtype: DTypeLike | None = None, *,
device: xc.Device | Sharding | None = None) -> Array:
if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m)
dtypes.check_user_dtype_supported(dtype, "empty")
return zeros(shape, dtype)
return zeros(shape, dtype, device=device)

def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore
if isinstance(dtype, int) and isinstance(shape, int):
Expand Down
16 changes: 12 additions & 4 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ from jax._src.lax.slicing import GatherScatterMode
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape
from jax.numpy import fft as fft, linalg as linalg
from jax.sharding import Sharding as _Sharding
import numpy as _np

_T = TypeVar('_T')

_Axis = Union[None, int, Sequence[int]]

# TODO(jakevdp): use xla_client.Device here
_Device = Any

ComplexWarning: type

_deprecations: dict[str, tuple[str, Any]]
Expand Down Expand Up @@ -300,7 +304,8 @@ def einsum(
) -> Array: ...

def einsum_path(subscripts, *operands, optimize = ...): ...
def empty(shape: Any, dtype: Optional[DTypeLike] = ...) -> Array: ...
def empty(shape: Any, dtype: Optional[DTypeLike] = ...,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def empty_like(prototype: Union[ArrayLike, DuckTypedArray],
dtype: Optional[DTypeLike] = ...,
shape: Any = ...) -> Array: ...
Expand Down Expand Up @@ -357,7 +362,8 @@ def fromstring(
string: str, dtype: DTypeLike = ..., count: int = ..., *, sep: str
) -> Array: ...
def full(shape: Any, fill_value: ArrayLike,
dtype: Optional[DTypeLike] = ...) -> Array: ...
dtype: Optional[DTypeLike] = ..., *,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def full_like(a: Union[ArrayLike, DuckTypedArray],
fill_value: ArrayLike, dtype: Optional[DTypeLike] = ...,
shape: Any = ...) -> Array: ...
Expand Down Expand Up @@ -599,7 +605,8 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ...
number = _np.number
object_ = _np.object_
ogrid: _Ogrid
def ones(shape: Any, dtype: Optional[DTypeLike] = ...) -> Array: ...
def ones(shape: Any, dtype: Optional[DTypeLike] = ...,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def ones_like(a: Union[ArrayLike, DuckTypedArray],
dtype: Optional[DTypeLike] = ...,
shape: Any = ...) -> Array: ...
Expand Down Expand Up @@ -874,7 +881,8 @@ def where(condition: ArrayLike, x: Optional[ArrayLike] = ...,
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
) -> Union[Array, tuple[Array, ...]]: ...

def zeros(shape: Any, dtype: Optional[DTypeLike] = ...) -> Array: ...
def zeros(shape: Any, dtype: Optional[DTypeLike] = ...,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def zeros_like(a: Union[ArrayLike, DuckTypedArray],
dtype: Optional[DTypeLike] = ...,
shape: Any = ...) -> Array: ...
Expand Down
27 changes: 25 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import jax.ops
from jax import lax
from jax import numpy as jnp
from jax.sharding import SingleDeviceSharding
from jax import tree_util
from jax.test_util import check_grads

Expand Down Expand Up @@ -2827,6 +2828,28 @@ def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full],
shape=array_shapes,
dtype=default_dtypes,
)
def testArrayCreationWithDevice(self, func, shape, dtype):
device = jax.devices()[-1]
kwds = {'fill_value': 1} if func is jnp.full else {}
out = func(**kwds, shape=shape, dtype=dtype, device=device)
self.assertEqual(out.devices(), {device})

@jtu.sample_product(
func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full],
shape=array_shapes,
dtype=default_dtypes,
)
def testArrayCreationWithSharding(self, func, shape, dtype):
sharding = SingleDeviceSharding(jax.devices()[-1])
kwds = {'fill_value': 1} if func is jnp.full else {}
out = func(**kwds, shape=shape, dtype=dtype, device=sharding)
self.assertEqual(out.sharding, sharding)

def testDuckTypedLike(self):
x = jax.ShapeDtypeStruct((1, 2, 3), np.dtype("int32"))
self.assertArraysEqual(jnp.zeros_like(x), jnp.zeros(x.shape, x.dtype))
Expand Down Expand Up @@ -5650,7 +5673,7 @@ def testWrappedSignaturesMatch(self):
'hstack': ['casting'],
'identity': ['like'],
'isin': ['kind'],
'full': ['device', 'order', 'like'],
'full': ['order', 'like'],
'full_like': ['device', 'subok', 'order'],
'fromfunction': ['like'],
'histogram': ['normed'],
Expand All @@ -5661,7 +5684,7 @@ def testWrappedSignaturesMatch(self):
'nanquantile': ['weights'],
'nanstd': ['correction', 'mean'],
'nanvar': ['correction', 'mean'],
'ones': ['device', 'order', 'like'],
'ones': ['order', 'like'],
'ones_like': ['device', 'subok', 'order'],
'partition': ['kind', 'order'],
'percentile': ['weights'],
Expand Down

0 comments on commit 8b81555

Please sign in to comment.