diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 38c6187f03a9..cca4306add0c 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -1487,17 +1487,9 @@ def iota(dtype: DType, size: int) -> Array:
   `_ operator.
   """
-  if config.omnistaging_enabled:
-    dtype = dtypes.canonicalize_dtype(dtype)
-    size = core.concrete_or_error(int, size, "size argument of lax.iota")
-    return iota_p.bind(dtype=dtype, shape=(size,), dimension=0)
-  else:
-    size = size if type(size) is masking.Poly else int(size)
-    shape = canonicalize_shape((size,))
-    dtype = dtypes.canonicalize_dtype(dtype)
-    lazy_expr = lazy.iota(dtype, shape[0])
-    aval = ShapedArray(shape, dtype)
-    return xla._DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  dtype = dtypes.canonicalize_dtype(dtype)
+  size = core.concrete_or_error(int, size, "size argument of lax.iota")
+  return iota_p.bind(dtype=dtype, shape=(size,), dimension=0)

 def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array:
   """Convenience wrapper around ``iota``."""
@@ -1512,14 +1504,9 @@ def _eye(dtype: DType, shape: Shape, offset: int) -> Array:
   N, M = tuple(map(int, shape))
   offset = int(offset)
   dtype = dtypes.canonicalize_dtype(dtype)
-  if config.omnistaging_enabled:
-    bool_eye = eq(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
-                  broadcasted_iota(np.int32, (N, M), 1))
-    return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False)
-  else:
-    lazy_expr = lazy.eye(dtype, (N, M), offset)
-    aval = ShapedArray((N, M), dtype)
-    return xla._DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  bool_eye = eq(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
+                broadcasted_iota(np.int32, (N, M), 1))
+  return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False)

 def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
   """This utility function exists for creating Kronecker delta arrays."""
@@ -1527,30 +1514,20 @@ def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
   axes = tuple(map(int, axes))
   dtype = dtypes.canonicalize_dtype(dtype)
   base_shape = tuple(np.take(shape, axes))  # type: ignore[arg-type]
-  if config.omnistaging_enabled:
-    iotas = [broadcasted_iota(np.uint32, base_shape, i)
-             for i in range(len(base_shape))]
-    eyes = [eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
-    result = convert_element_type_p.bind(_reduce(operator.and_, eyes), new_dtype=dtype, weak_type=False)
-    return broadcast_in_dim(result, shape, axes)
-  else:
-    lazy_expr = lazy.broadcast(lazy.delta(dtype, base_shape), shape, axes)
-    aval = ShapedArray(shape, dtype)
-    return xla._DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  iotas = [broadcasted_iota(np.uint32, base_shape, i)
+           for i in range(len(base_shape))]
+  eyes = [eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
+  result = convert_element_type_p.bind(_reduce(operator.and_, eyes), new_dtype=dtype, weak_type=False)
+  return broadcast_in_dim(result, shape, axes)

 def _tri(dtype: DType, shape: Shape, offset: int) -> Array:
   """Like numpy.tri, create a 2D array with ones below a diagonal."""
   N, M = tuple(map(int, shape))
   offset = int(offset)
   dtype = dtypes.canonicalize_dtype(dtype)
-  if config.omnistaging_enabled:
-    bool_tri = ge(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
-                  broadcasted_iota(np.int32, (N, M), 1))
-    return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False)
-  else:
-    lazy_expr = lazy.tri(dtype, (N, M), offset)
-    aval = ShapedArray((N, M), dtype)
-    return xla._DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  bool_tri = ge(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
+                broadcasted_iota(np.int32, (N, M), 1))
+  return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False)

 def stop_gradient(x):
   """Stops gradient computation.
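Note on the lax.py hunks above: with the lazy Iota/Eye/Tri constants gone, `_eye` and `_tri` are staged as comparisons between two broadcasted iotas. A minimal standalone sketch of the same construction through the public `lax` API (the patch itself calls the internal `convert_element_type_p.bind` to control weak typing; the function names here are illustrative, not JAX API):

    import numpy as np
    from jax import lax

    def eye_via_iotas(N, M, offset, dtype):
      # ones exactly where row + offset == col, i.e. the k=offset diagonal (like np.eye)
      rows = lax.broadcasted_iota(np.int32, (N, M), 0)
      cols = lax.broadcasted_iota(np.int32, (N, M), 1)
      return lax.convert_element_type(lax.eq(lax.add(rows, np.int32(offset)), cols), dtype)

    def tri_via_iotas(N, M, offset, dtype):
      # ones on and below the k=offset diagonal, where row + offset >= col (like np.tri)
      rows = lax.broadcasted_iota(np.int32, (N, M), 0)
      cols = lax.broadcasted_iota(np.int32, (N, M), 1)
      return lax.convert_element_type(lax.ge(lax.add(rows, np.int32(offset)), cols), dtype)

    assert np.array_equal(eye_via_iotas(3, 4, 1, np.float32),
                          np.eye(3, 4, k=1, dtype=np.float32))
    assert np.array_equal(tri_via_iotas(4, 4, 0, np.float32),
                          np.tri(4, dtype=np.float32))

The appeal of this construction is that the result is an ordinary traced computation, so it needs no special DeviceConstant plumbing and can still constant-fold under jit.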
diff --git a/jax/api.py b/jax/api.py
index a516fe635688..f391bcae631d 100644
--- a/jax/api.py
+++ b/jax/api.py
@@ -311,7 +311,7 @@ def cache_miss(_, *args, **kwargs):
         # has been reset to None). Thus, we do not support the fast-path.
         execute is not None and execute.func is xla._execute_compiled and  # not trivial, not pmap
-        # Not supported: ShardedDeviceArray, DeviceConstant.
+        # Not supported: ShardedDeviceArray
         all(xla.type_is_device_array(x) for x in out_flat) and
         # TODO(mattjj): Add support for lazy-expression.
         # If the input is a DeviceArray, then it should have a trivial LazyExpr.
diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py
index c7437c9016bd..df0e4b32bc48 100644
--- a/jax/interpreters/xla.py
+++ b/jax/interpreters/xla.py
@@ -1034,15 +1034,14 @@ def make_device_array(
     aval: core.ShapedArray,
     device: Optional[Device],
     lazy_expr: Optional[lazy.LazyExpr],
-    device_buffer: Union[PyLocalBuffer, "DeviceConstant"],
+    device_buffer: PyLocalBuffer,
 ) -> Union[PyLocalBuffer, "_DeviceArray"]:
   """Returns a DeviceArray implementation based on arguments.

   This is to be used only within JAX. It will return either a PythonDeviceArray
   or a C++ equivalent implementation.
   """
-  if (_EXPERIMENTAL_CPP_DEVICE_ARRAY and lazy.is_trivial(lazy_expr) and
-      not isinstance(device_buffer, DeviceConstant)):
+  if _EXPERIMENTAL_CPP_DEVICE_ARRAY and lazy.is_trivial(lazy_expr):
     assert isinstance(device_buffer, _CppDeviceArray)
     device_buffer._device = device  # pylint: disable=protected-access
     device_buffer.aval = aval
@@ -1120,10 +1119,7 @@ def block_until_ready(self):
   def _value(self):
     self._check_if_deleted()
     if self._npy_value is None:
-      if is_device_constant(self):
-        self._npy_value = lazy.eval_lexpr(self._lazy_expr, None)
-      else:
-        self._npy_value = _force(self).device_buffer.to_py()
+      self._npy_value = _force(self).device_buffer.to_py()
       self._npy_value.flags.writeable = False
     return self._npy_value
@@ -1146,7 +1142,7 @@ def ndim(self):
   def copy_to_host_async(self):
     """Requests a copy of the buffer to the host."""
     self._check_if_deleted()
-    if self._npy_value is None and not is_device_constant(self):
+    if self._npy_value is None:
       self.device_buffer.copy_to_host_async()  # pytype: disable=attribute-error

   def delete(self):
@@ -1282,14 +1278,6 @@ def raise_not_implemented():
 class DeletedBuffer(object): pass
 deleted_buffer = DeletedBuffer()

-class DeviceConstant(object):
-  __slots__ = ["_device"]
-  def __init__(self, device=None): self._device = device
-  def device(self): return self._device
-  def to_py(self): return None
-
-def is_device_constant(x):
-  return type_is_device_array(x) and type(x.device_buffer) is DeviceConstant

 for device_array in [_CppDeviceArray, _DeviceArray]:
   core.literalable_types.add(device_array)
@@ -1298,11 +1286,8 @@ def is_device_constant(x):
   canonicalize_dtype_handlers[device_array] = identity

 def _device_array_constant_handler(c, val, canonicalize_types=True):
-  if is_device_constant(val):
-    return lazy.stage_lexpr(c, val._lazy_expr, None)
-  else:
-    base_val = xb.constant(c, val.device_buffer.to_py())
-    return lazy.stage_lexpr(c, val._lazy_expr, base_val)
+  base_val = xb.constant(c, val.device_buffer.to_py())
+  return lazy.stage_lexpr(c, val._lazy_expr, base_val)
 xb.register_constant_handler(_DeviceArray, _device_array_constant_handler)
 xb.register_constant_handler(_CppDeviceArray, _device_array_constant_handler)

@@ -1316,10 +1301,6 @@ def _copy_device_array_to_device(x: Union[DeviceArrayProtocol, _DeviceArray], de
   if device is None:
     # no copying to be done because there's no target specified
     return x
-  elif is_device_constant(x):
-    # create a new DeviceArray with the same lazy expr, no copying
-    return make_device_array(x.aval, device, x._lazy_expr,
-                             DeviceConstant(device))
   elif xb.get_device_backend(device).platform == x.device_buffer.platform():
     # source and target platforms are the same
     if x.device_buffer.device() == device:
@@ -1355,14 +1336,11 @@ def _lazy_force_computation(aval: core.ShapedArray,
                             device: Device, lexpr: lazy.LazyExpr
                             ) -> Callable[[_DeviceArray], PyLocalBuffer]:
   c = xb.make_computation_builder("lazy_force")
-  if lazy.is_constant(lexpr):
-    param = None
-  else:
-    idxs = [(src, dst) for dst, src in enumerate(lexpr.dims) if src is not None]
-    param_shape = [None] * len(idxs)
-    for src, dst in idxs:
-      param_shape[src] = aval.shape[dst]
-    param = xb.parameter(c, 0, xc.Shape.array_shape(aval.dtype, param_shape))
+  idxs = [(src, dst) for dst, src in enumerate(lexpr.dims) if src is not None]
+  param_shape = [None] * len(idxs)
+  for src, dst in idxs:
+    param_shape[src] = aval.shape[dst]
+  param = xb.parameter(c, 0, xc.Shape.array_shape(aval.dtype, param_shape))
   xla_out = lazy.stage_lexpr(c, lexpr, param)
   built_c = c.build(xla_out)

@@ -1373,13 +1351,8 @@ def _lazy_force_computation(aval: core.ShapedArray,
       device_assignment=device and (device.id,))
   compiled = backend_compile(xb.get_device_backend(device), built_c, options)

-  force_fun: Callable[[_DeviceArray], PyLocalBuffer]
-  if lazy.is_constant(lexpr):
-    def force_fun(_):
-      return compiled.execute([])[0]
-  else:
-    def force_fun(x):
-      return compiled.execute([x.device_buffer])[0]
+  def force_fun(x: _DeviceArray) -> PyLocalBuffer:
+    return compiled.execute([x.device_buffer])[0]
   return force_fun
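Note on the `_lazy_force_computation` hunks: every LazyExpr that survives this patch reindexes exactly one input buffer, so the forcing computation now unconditionally declares one XLA parameter and `force_fun` always passes the device buffer. A standalone rerun of the `param_shape` bookkeeping, with a hypothetical helper name, to show how the parameter shape is recovered from `dims`:

    def param_shape_for(dims, out_shape):
      # dims[dst] == src means input axis src appears as output axis dst;
      # None entries are broadcast axes with no corresponding input axis.
      idxs = [(src, dst) for dst, src in enumerate(dims) if src is not None]
      shape = [None] * len(idxs)
      for src, dst in idxs:
        shape[src] = out_shape[dst]
      return shape

    # A (3,)-shaped input broadcast to (4, 3) via dims=(None, 0): parameter shape [3].
    assert param_shape_for((None, 0), (4, 3)) == [3]
    # A (2, 5) input viewed transposed as (5, 2) via dims=(1, 0): parameter shape [2, 5].
    assert param_shape_for((1, 0), (5, 2)) == [2, 5]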
diff --git a/jax/lazy.py b/jax/lazy.py
index 162a620f11b9..66f3ce4cb42f 100644
--- a/jax/lazy.py
+++ b/jax/lazy.py
@@ -14,14 +14,11 @@

 from collections import namedtuple
-import functools
-import operator as op
 from typing import Optional, Sequence

 import numpy as np

-from ._src.util import safe_map, safe_zip, unzip2, subvals, taggedtuple
-from .lib import xla_bridge as xb
+from ._src.util import safe_map, safe_zip, unzip2
 from .lib import xla_client as xc

 from ._src import traceback_util
@@ -35,26 +32,9 @@

 ### lazy sublanguage

-# There are two components to a LazyExpr: an input and a reindexing
-# specification. The input represents a base array to which the reindexing
-# specification is applied.
-#
-# An input can represent an array constructor (Iota, Eye, etc.) or it can be an
-# ArrayVar which encodes that the base array is some exogenous array value (from
-# an environment with only a single value in it). These LazyExprs are attached
-# to DeviceArrays, so when the input part of the expression is ArrayVar that
-# basically means the associated device buffer represents the input, while if
-# the input is an array constructor then the associated device_buffer field of
-# the DeviceArray should be set to a DeviceConstant sentinel value. For the
-# array constructor expressions:
-#   * Iota builds a 1D sequence [0, 1, ..., N-1],
-#   * Eye builds a 2D array with ones on a (possibly offset) diagonal and zeros
-#     elsewhere (like numpy.eye),
-#   * Tri builds a triangular matrix with ones on and below a diagonal and zeros
-#     elsewhere (like numpy.tri), and
-#   * Delta builds a Kronecker delta array with ones along its multidimensional
-#     main diagonal and zeros elsewhere (for use in tensor contractions).
-#
+# A LazyExpr contains a reindexing specification. The reindexing expression is
+# applied to the value represented by the device buffer backing a DeviceArray.
+
 # The reindexing specification encodes the shape of the final result and a list
 # of dimensions, which are integers or Nones. The integer entries take on values
 # 0, 1, ..., R-1 where R is the rank of the input array, and encode where the
 # axes of the input array appear in the result. An entry can instead be a
 # None that indicates that the corresponding axis of the result is a broadcasted
 # one.
 #
-# Here are some examples of lazy expressions and the arrays they represent:
-#
-# LazyExpr(input=Iota(dtype=dtype('float32'), size=3),
-#          shape=(3, 4), dims=(0, None))
-# DeviceArray([[0., 0., 0., 0.],
-#              [1., 1., 1., 1.],
-#              [2., 2., 2., 2.]], dtype=float32)
-#
-# LazyExpr(input=Iota(dtype=dtype('float32'), size=3),
-#          shape=(4, 3), dims=(None, 0))
-# DeviceArray([[0., 1., 2.],
-#              [0., 1., 2.],
-#              [0., 1., 2.],
-#              [0., 1., 2.]], dtype=float32)
-#
 # For performance, some functions on lazy expressions accept None as an input to
 # stand for the identity lazy expression.
 #
@@ -88,49 +53,32 @@
 #   hash(A(1, 2)) == hash(B(1, 2))  # True
 # but we want hashes to be sensitive to the type tag (while still being fast).

+# TODO(phawkins): remove `input` from LazyExpr when jaxlib 0.1.63 is the minimum
+
 # pytype: disable=wrong-arg-count
 LazyExpr = namedtuple('LazyExpr', ['input', 'shape', 'dims'])
-ArrayVar = taggedtuple('ArrayVar', [])
-Iota = taggedtuple('Iota', ['dtype', 'size'])  # like np.arange(N)
-Eye = taggedtuple('Eye', ['dtype', 'shape', 'offset'])  # like np.eye
-Tri = taggedtuple('Tri', ['dtype', 'shape', 'offset'])  # like np.tri
-Delta = taggedtuple('Delta', ['dtype', 'shape'])  # kronecker delta arrays
 # pytype: enable=wrong-arg-count

-def array(shape):
-  return LazyExpr(ArrayVar(), shape, tuple(range(len(shape))))
-
-def iota(dtype, size):
-  return LazyExpr(Iota(dtype, size), (size,), (0,))
-
-def eye(dtype, shape, offset):
-  assert len(shape) == 2
-  return LazyExpr(Eye(dtype, shape, offset), shape, (0, 1))
-
-def tri(dtype, shape, offset):
-  assert len(shape) == 2
-  return LazyExpr(Tri(dtype, shape, offset), shape, (0, 1))
+# TODO(phawkins): remove `ArrayVar` when jaxlib 0.1.63 is the minimum.
+class ArrayVar:
+  pass

-def delta(dtype, shape):
-  return LazyExpr(Delta(dtype, shape), shape, tuple(range(len(shape))))
+def array(shape):
+  return LazyExpr(None, shape, tuple(range(len(shape))))

 def broadcast(lexpr, shape, broadcast_dimensions):
   new_dims = [None] * len(shape)
   for i, d in enumerate(broadcast_dimensions):
     new_dims[d] = lexpr.dims[i]
-  return LazyExpr(lexpr.input, shape, tuple(new_dims))
+  return LazyExpr(None, shape, tuple(new_dims))

 def transpose(lexpr: LazyExpr, perm: Sequence[int]):
   new_shape = tuple(lexpr.shape[i] for i in perm)
   new_dims = tuple(lexpr.dims[i] for i in perm)
-  return LazyExpr(lexpr.input, new_shape, new_dims)
-
-def is_constant(lexpr: Optional[LazyExpr]):
-  return lexpr is not None and type(lexpr.input) is not ArrayVar
+  return LazyExpr(None, new_shape, new_dims)

 def is_trivial(lexpr: Optional[LazyExpr]) -> bool:
-  return lexpr is None or (type(lexpr.input) is ArrayVar and
-                           lexpr.dims == tuple(range(len(lexpr.shape))))
+  return lexpr is None or (lexpr.dims == tuple(range(len(lexpr.shape))))


 def eval_lexpr(lexpr, x):
@@ -141,36 +89,12 @@ def eval_lexpr(lexpr, x):
   Returns: An ndarray representing the value of the lazy expression.
   """
-  if lexpr is None or is_trivial(lexpr):
+  if is_trivial(lexpr):
     return x

-  input_, shape, dims = lexpr
-
-  # first create a starting ndarray from input_
-  t = type(input_)
-  if t is ArrayVar:
-    assert x is not None and type(x) is np.ndarray
-  elif t is Iota:
-    assert x is None
-    x = np.arange(input_.size, dtype=input_.dtype)
-  elif t is Eye:
-    assert x is None
-    N, M = input_.shape
-    x = np.eye(N, M, dtype=input_.dtype, k=input_.offset)
-  elif t is Tri:
-    assert x is None
-    N, M = input_.shape
-    x = np.tri(N, M, dtype=input_.dtype, k=input_.offset)
-  elif t is Delta:
-    ones = [1] * len(input_.shape)
-    iotas = [np.arange(d).reshape(subvals(ones, [(i, -1)]))
-             for i, d in enumerate(input_.shape)]
-    eyes = [i1 == i2 for i1, i2 in zip(iotas[:-1], iotas[1:])]
-    x = np.asarray(functools.reduce(op.and_, eyes), input_.dtype)
-  else:
-    assert False
-
-  # then apply the reindexing operation
+  assert x is not None
+  _, shape, dims = lexpr
+
   perm = [d for d in dims if d is not None]
   if perm != list(range(len(perm))):
     x = np.transpose(x, perm)
@@ -181,7 +105,7 @@ def eval_lexpr(lexpr, x):
   return x


-def stage_lexpr(c, lexpr: Optional[LazyExpr], x):
+def stage_lexpr(c, lexpr, x):
   """Stage a lazy expression into an XLA computation.
   Args:
     c: XLA ComputationBuilder into which to stage the expression.
@@ -190,46 +114,11 @@ def stage_lexpr(c, lexpr: Optional[LazyExpr], x):
   Returns: An XlaOp representing the value of the lazy expression.
   """
-  if lexpr is None or is_trivial(lexpr):
+  if is_trivial(lexpr):
     return x

-  input_, shape, dims = lexpr
-
-  # first create a starting XlaOp from input_
-  t = type(input_)
-  if t is ArrayVar:
-    assert x is not None
-  elif t is Iota:
-    assert x is None
-    x = xops.Iota(c, xb.dtype_to_etype(input_.dtype), input_.size)
-  elif t is Eye:
-    assert x is None
-    N, M = input_.shape
-    xla_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, (N, M))
-    bool_eye = xops.Eq(
-        xops.Add(xops.Iota(c, xla_shape, 0),
-                 xb.constant(c, np.array(input_.offset, np.int32))),
-        xops.Iota(c, xla_shape, 1))
-    x = xops.ConvertElementType(bool_eye, xb.dtype_to_etype(input_.dtype))
-  elif t is Tri:
-    assert x is None
-    N, M = input_.shape
-    xla_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, (N, M))
-    bool_tri = xops.Ge(
-        xops.Add(xops.Iota(c, xla_shape, 0),
-                 xb.constant(c, np.array(input_.offset, np.int32))),
-        xops.Iota(c, xla_shape, 1))
-    x = xops.ConvertElementType(bool_tri, xb.dtype_to_etype(input_.dtype))
-  elif t is Delta:
-    etype = xb.dtype_to_etype(input_.dtype)
-    iotas = [xops.Iota(c, xc.Shape.array_shape(xc.PrimitiveType.U32, input_.shape), i)
-             for i in range(len(input_.shape))]
-    eyes = [xops.Eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
-    x = xops.ConvertElementType(functools.reduce(xops.And, eyes), etype)
-  else:
-    assert False
-
-  # then apply the operations encoded in reindex
+  assert x is not None
+  _, shape, dims = lexpr
+
   bcast_dims, perm = unzip2((i, d) for i, d in enumerate(dims) if d is not None)
   if tuple(perm) != tuple(range(len(perm))):
     x = xops.Transpose(x, perm)
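The examples deleted from the comment in jax/lazy.py both used an Iota constant input, which no longer exists; the surviving LazyExprs only reindex a real buffer. A hedged numpy model of that reindexing, mirroring the transpose logic visible in `eval_lexpr` (the reshape/broadcast step is an assumption about the unshown remainder of the function, and `reindex` is an illustrative name, not JAX API):

    import numpy as np

    def reindex(x, shape, dims):
      # Transpose so the input axes land in the order they occupy in the output...
      perm = [d for d in dims if d is not None]
      if perm != list(range(len(perm))):
        x = np.transpose(x, perm)
      # ...then give the None entries size-1 axes and broadcast them out.
      if x.shape != tuple(shape):
        in_shape = [1 if d is None else s for d, s in zip(dims, shape)]
        x = np.broadcast_to(np.reshape(x, in_shape), shape)
      return x

    # dims=(None, 0): the input's only axis becomes output axis 1 and axis 0 is
    # broadcast, reproducing the second deleted example minus its Iota input:
    assert np.array_equal(reindex(np.arange(3, dtype=np.float32), (4, 3), (None, 0)),
                          np.broadcast_to(np.arange(3, dtype=np.float32), (4, 3)))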
""" - if lexpr is None or is_trivial(lexpr): + if is_trivial(lexpr): return x - input_, shape, dims = lexpr - - # first create a starting XlaOp from input_ - t = type(input_) - if t is ArrayVar: - assert x is not None - elif t is Iota: - assert x is None - x = xops.Iota(c, xb.dtype_to_etype(input_.dtype), input_.size) - elif t is Eye: - assert x is None - N, M = input_.shape - xla_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, (N, M)) - bool_eye = xops.Eq( - xops.Add(xops.Iota(c, xla_shape, 0), - xb.constant(c, np.array(input_.offset, np.int32))), - xops.Iota(c, xla_shape, 1)) - x = xops.ConvertElementType(bool_eye, xb.dtype_to_etype(input_.dtype)) - elif t is Tri: - assert x is None - N, M = input_.shape - xla_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, (N, M)) - bool_tri = xops.Ge( - xops.Add(xops.Iota(c, xla_shape, 0), - xb.constant(c, np.array(input_.offset, np.int32))), - xops.Iota(c, xla_shape, 1)) - x = xops.ConvertElementType(bool_tri, xb.dtype_to_etype(input_.dtype)) - elif t is Delta: - etype = xb.dtype_to_etype(input_.dtype) - iotas = [xops.Iota(c, xc.Shape.array_shape(xc.PrimitiveType.U32, input_.shape), i) - for i in range(len(input_.shape))] - eyes = [xops.Eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])] - x = xops.ConvertElementType(functools.reduce(xops.And, eyes), etype) - else: - assert False - - # then apply the operations encoded in reindex + assert x is not None + _, shape, dims = lexpr bcast_dims, perm = unzip2((i, d) for i, d in enumerate(dims) if d is not None) if tuple(perm) != tuple(range(len(perm))): x = xops.Transpose(x, perm) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 94428c5b5e89..15afe4af0db8 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4150,11 +4150,6 @@ def testArange(self): self.assertEqual(type(jnp.arange(77, dtype=jnp.int32)), type(lax.iota(np.int32, 77))) - # test laziness for int dtypes - if not config.omnistaging_enabled: - self.assertTrue(xla.is_device_constant(jnp.arange(77))) - self.assertTrue(xla.is_device_constant(jnp.arange(77, dtype=jnp.int32))) - def testArangeJit(self): ans = api.jit(lambda: jnp.arange(5))() expected = np.arange(5) diff --git a/tests/random_test.py b/tests/random_test.py index 18730ad58f38..d1174c437a01 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -901,6 +901,8 @@ def testChoiceShapeIsNotSequenceError(self): random.choice(key, 5, 2, replace=True) def test_eval_shape_big_random_array(self): + if not config.omnistaging_enabled: + raise SkipTest("after deleting lazy constants, requires omnistaging") def f(x): return random.normal(random.PRNGKey(x), (int(1e12),)) with core.skipping_checks(): # check_jaxpr will materialize array