diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 94d0a6f55d3e..049e50486887 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -30,6 +30,7 @@ from ..config import flags from .. import core from .. import linear_util as lu +from .. import lazy from ..abstract_arrays import (ConcreteArray, ShapedArray, array_types, raise_to_shaped) from ..util import partial, unzip2, concatenate, prod, safe_map @@ -360,8 +361,7 @@ def __getitem__(self, idx): ids = self._ids() device_buffer = self.device_buffers[ids[idx]] aval = ShapedArray(self.aval.shape[1:], self.aval.dtype) - handler = xla.aval_to_result_handler(None, aval) - return handler(device_buffer) + return xla.DeviceArray(aval, None, lazy.array(aval.shape), device_buffer) else: return super(ShardedDeviceArray, self).__getitem__(idx) @@ -376,11 +376,14 @@ def _shard_sharded_device_array(x, devices, assignments): return (xla.device_put(x[assignments[r]], devices[r]) for r in range(n)) shard_arg_handlers[ShardedDeviceArray] = _shard_sharded_device_array +def _sharded_device_array_constant_handler(c, val, canonicalize_types=True): + return c.Constant(onp.asarray(val), canonicalize_types=canonicalize_types) +xb.register_constant_handler(ShardedDeviceArray, _sharded_device_array_constant_handler) + core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray xla.device_put_handlers[ShardedDeviceArray] = xla._device_put_array -xla.pytype_aval_mappings[ShardedDeviceArray] = lambda x: x.aval +xla.pytype_aval_mappings[ShardedDeviceArray] = op.attrgetter('aval') xla.canonicalize_dtype_handlers[ShardedDeviceArray] = identity -xb.register_constant_handler(ShardedDeviceArray, xla._device_array_constant_handler) class ChunkedDeviceArray(ShardedDeviceArray): @@ -398,10 +401,8 @@ def __getitem__(self, idx): core.pytype_aval_mappings[ChunkedDeviceArray] = ConcreteArray xla.device_put_handlers[ChunkedDeviceArray] = xla._device_put_array -xla.pytype_aval_mappings[ChunkedDeviceArray] = lambda x: x.aval +xla.pytype_aval_mappings[ChunkedDeviceArray] = op.attrgetter('aval') xla.canonicalize_dtype_handlers[ChunkedDeviceArray] = identity -xb.register_constant_handler(ChunkedDeviceArray, - xla._device_array_constant_handler) ### the xla_pmap primitive and its rules are comparable to xla_call in xla.py diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 76523d5abff5..852920727a17 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -32,6 +32,7 @@ from .. import ad_util from .. import tree_util from .. import dtypes +from .. import lazy from .. 
import linear_util as lu from ..abstract_arrays import (ConcreteArray, ShapedArray, AbstractToken, make_shaped_array, array_types, raise_to_shaped, @@ -78,7 +79,7 @@ def aval_to_result_handler(device, aval): xla_result_handlers = {} xla_result_handlers[core.AbstractUnit] = lambda _, __: lambda _: core.unit def array_result_handler(device, aval): - return partial(DeviceArray, raise_to_shaped(aval), device) + return partial(DeviceArray, raise_to_shaped(aval), device, lazy.array(aval.shape)) xla_result_handlers[ShapedArray] = array_result_handler xla_result_handlers[ConcreteArray] = array_result_handler @@ -146,6 +147,7 @@ def _make_abstract_python_scalar(typ, _): for _t in dtypes.python_scalar_dtypes.keys(): pytype_aval_mappings[_t] = partial(_make_abstract_python_scalar, _t) + ### op-by-op execution def arg_spec(x): @@ -536,6 +538,7 @@ def _xla_callable_args(c, avals, tuple_args): def _pval_to_result_handler(device, pval): pv, const = pval if pv is None: + const = _device_put_impl(const, device) if device else const return lambda _: const else: return aval_to_result_handler(device, pv) @@ -566,8 +569,8 @@ def _execute_trivial(jaxpr, device, consts, handlers, *args): _map(env.setdefault, jaxpr.constvars, consts) outs = [canonicalize_dtype(v.val) if type(v) is Literal else env[v] for v in jaxpr.outvars] - return [x if type(x) is DeviceArray else handler(device_put(x, device)) - for handler, x in zip(handlers, outs)] + return [_copy_device_array_to_device(x, device) if type(x) is DeviceArray + else h(device_put(x, device)) for h, x in zip(handlers, outs)] def make_tuple(bufs, device, backend): return xb.get_backend(backend).make_tuple(bufs, device) @@ -678,7 +681,7 @@ def __init__(self, aval, device_buffer): self.device_buffer = device_buffer def _check_if_deleted(self): - if self.device_buffer is None: + if self.device_buffer is deleted_buffer: raise ValueError("DeviceValue has been deleted.") def block_until_ready(self): @@ -702,13 +705,14 @@ class DeviceArray(DeviceValue): """A DeviceArray is an ndarray backed by a single device memory buffer.""" # We don't subclass ndarray because that would open up a host of issues, # but lax_numpy.py overrides isinstance behavior and attaches ndarray methods. - __slots__ = ["_npy_value", "_device"] + __slots__ = ["_npy_value", "_device", "_lazy_expr"] __array_priority__ = 100 - def __init__(self, aval, device, device_buffer): + def __init__(self, aval, device, lazy_expr, device_buffer): self.aval = aval self.device_buffer = device_buffer self._device = device and (type(device), device.id) + self._lazy_expr = lazy_expr self._npy_value = None if not core.skip_checks: @@ -720,7 +724,10 @@ def __init__(self, aval, device, device_buffer): def _value(self): self._check_if_deleted() if self._npy_value is None: - self._npy_value = self.device_buffer.to_py() + if is_device_constant(self): + self._npy_value = lazy.eval_lexpr(self._lazy_expr, None) + else: + self._npy_value = _force(self).device_buffer.to_py() self._npy_value.flags.writeable = False return self._npy_value @@ -747,7 +754,7 @@ def copy(self): def copy_to_host_async(self): """Requests a copy of the buffer to the host.""" self._check_if_deleted() - if self._npy_value is None: + if self._npy_value is None and not is_device_constant(self): self.device_buffer.copy_to_host_async() def delete(self): @@ -762,7 +769,7 @@ def delete(self): time of deletion. 
""" self.device_buffer.delete() - self.device_buffer = None + self.device_buffer = deleted_buffer self._npy_value = None def __repr__(self): @@ -837,31 +844,97 @@ def __eq__(self, other): return self._value == other def __hash__(self): raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.") +class DeletedBuffer(object): pass +deleted_buffer = DeletedBuffer() + +class DeviceConstant(object): + __slots__ = ["_device"] + def __init__(self, device=None): self._device = device + def device(self): return self._device + def to_py(self): return None + +def is_device_constant(x): + return type(x) is DeviceArray and type(x.device_buffer) is DeviceConstant + core.literalable_types.add(DeviceArray) core.pytype_aval_mappings[DeviceArray] = ConcreteArray -pytype_aval_mappings[DeviceArray] = lambda x: x.aval +pytype_aval_mappings[DeviceArray] = op.attrgetter('aval') canonicalize_dtype_handlers[DeviceArray] = identity def _device_array_constant_handler(c, val, canonicalize_types=True): - return c.Constant(onp.asarray(val), canonicalize_types=canonicalize_types) + if is_device_constant(val): + return lazy.stage_lexpr(c, val._lazy_expr, None) + else: + base_val = c.Constant(val.device_buffer.to_py()) + return lazy.stage_lexpr(c, val._lazy_expr, base_val) xb.register_constant_handler(DeviceArray, _device_array_constant_handler) def _device_put_device_array(x, device): - # TODO(skye): we're assuming the DeviceBuffers without "platform" are - # XrtBuffers. Figure out a less risky way to deal with XrtBuffers. - if (not hasattr(x.device_buffer, "platform") or - xb.get_device_backend(device).platform == x.device_buffer.platform()): + x = _copy_device_array_to_device(x, device) + return _force(x).device_buffer +device_put_handlers[DeviceArray] = _device_put_device_array + +def _copy_device_array_to_device(x, device): + if is_device_constant(x): + return DeviceArray(x.aval, device, x._lazy_expr, DeviceConstant(device)) + elif xb.get_device_backend(device).platform == x.device_buffer.platform(): if device is None or x.device_buffer.device() == device: - return x.device_buffer + return x else: - return x.device_buffer.copy_to_device(device) + moved_buf = x.device_buffer.copy_to_device(device) else: - # Buffers from different XLA backends are passed through the host. - return xc.Buffer.from_pyval(x, device, backend=xb.get_device_backend(device)) -device_put_handlers[DeviceArray] = _device_put_device_array + # Buffers from different XLA backends are passed through the host. 
+ moved_buf = xc.Buffer.from_pyval(x.device_buffer.to_py(), device, + backend=xb.get_device_backend(device)) + return DeviceArray(x.aval, device, x._lazy_expr, moved_buf) + +def _force(x): + if lazy.is_trivial(x._lazy_expr): + return x + else: + # force x on the device where it lives, but preserve stickiness on result + if x._device: + device = x._device + sticky = True + else: + d = x.device_buffer.device() + device = d and (type(d), d.id) + sticky = False + force_fun = _lazy_force_computation(sticky, x.aval, device, x._lazy_expr) + return force_fun(x) + +@cache() +def _lazy_force_computation(sticky, aval, device, lexpr): + c = xb.make_computation_builder("lazy_force") + if lazy.is_constant(lexpr): + param = None + else: + idxs = [(src, dst) for dst, src in enumerate(lexpr.dims) if src is not None] + param_shape = [None] * len(idxs) + for src, dst in idxs: + param_shape[src] = aval.shape[dst] + param = c.ParameterWithShape(xc.Shape.array_shape(aval.dtype, param_shape)) + xla_out = lazy.stage_lexpr(c, lexpr, param) + built_c = c.Build(xla_out) + + device = _device_from_arg_devices([device]) + options = xb.get_compile_options(device_assignment=device and (device.id,)) + backend = xb.get_device_backend(device) + compiled = built_c.Compile(compile_options=options, backend=backend) + + result_device = device if sticky else None + handler = partial(DeviceArray, aval, result_device, lazy.array(aval.shape)) + if lazy.is_constant(lexpr): + force_fun = lambda _: handler(compiled.Execute([])) + else: + force_fun = lambda x: handler(compiled.Execute([x.device_buffer])) + return force_fun def _device_put_impl(x, device=None): + if type(x) is DeviceArray: + return _copy_device_array_to_device(x, device) + try: a = abstractify(x) except TypeError: @@ -904,28 +977,3 @@ def _foil_cse(c, x): shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype() zero = c.Broadcast(c.Constant(onp.array(0, dtype=dtype)), shape) return c.Select(pred, x, zero) - - -### lazy constants - -class DeviceConstant(DeviceArray): - def copy_to_host_async(self): pass - - @staticmethod - def constant_handler(c, constant_instance, canonicalize_types=True): - assert False - -def _instantiate_device_constant(const, device=None, backend=None, cutoff=1e6): - # dispatch an XLA Computation to build the constant on the device if it's - # large, or alternatively build it on the host and transfer it if it's small - assert isinstance(const, DeviceConstant) - backend = xb.get_backend(device.platform) if device else xb.get_backend(backend) - if const.size > cutoff: - c = xb.make_computation_builder("constant_instantiating_computation") - xla_const = const.constant_handler(c, const) - device_assignment = (device.id,) if device else None - opts = xb.get_compile_options(device_assignment=device_assignment) - compiled = c.Build(xla_const).Compile((), opts, backend=backend) - return compiled.Execute(()) - else: - return xc.Buffer.from_pyval(onp.asarray(const), device, backend=backend) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index c40c3531b233..8a7192a91182 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -21,7 +21,7 @@ _input_dtype, _const, _eq_meet, _safe_mul, _broadcasting_select, _check_user_dtype_supported, _one, _const, _upcast_fp16_for_computation, - _broadcasting_shape_rule) + _broadcasting_shape_rule, _eye, _tri, _delta) from .lax_control_flow import * from .lax_fft import * from .lax_parallel import * diff --git a/jax/lax/lax.py b/jax/lax/lax.py index 38e4d627e35c..dc3789df40e3 100644 --- a/jax/lax/lax.py +++ 
b/jax/lax/lax.py
@@ -36,6 +36,7 @@
 from .. import api
 from .. import linear_util as lu
 from .. import dtypes
+from .. import lazy
 from ..config import flags
 from ..core import Primitive
 from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
@@ -615,15 +616,12 @@ def broadcast(operand, sizes):
   Returns:
     An array containing the result.
   """
-  return broadcast_p.bind(operand, sizes=tuple(sizes))
+  dims = tuple(range(len(sizes), len(sizes) + onp.ndim(operand)))
+  return broadcast_in_dim(operand, tuple(sizes) + onp.shape(operand), dims)
 
 def broadcast_in_dim(operand, shape, broadcast_dimensions):
   if onp.ndim(operand) == len(shape) and not len(broadcast_dimensions):
     return operand
-  if any(x < 0 or x >= len(shape) for x in broadcast_dimensions):
-    msg = ("broadcast dimensions must be >= 0 and < ndim(shape), got {} for "
-           "shape {}")
-    raise ValueError(msg.format(broadcast_dimensions, shape))
   return broadcast_in_dim_p.bind(
       operand, shape=tuple(shape),
       broadcast_dimensions=tuple(broadcast_dimensions))
@@ -1060,47 +1058,64 @@ def full(shape, fill_value, dtype=None):
   if onp.shape(fill_value):
     msg = "full must be called with scalar fill_value, got fill_value.shape {}."
     raise TypeError(msg.format(onp.shape(fill_value)))
-  dtype = dtype or _dtype(fill_value)
-  dtype = dtypes.canonicalize_dtype(dtype)
-
-  # For constants (defined as Python scalars, raw ndarrays, or DeviceValues),
-  # create a _FilledConstant value, otherwise just call broadcast.
-  if onp.isscalar(fill_value) or type(fill_value) is onp.ndarray:
-    return _FilledConstant(onp.asarray(fill_value, dtype), shape)
-  elif isinstance(fill_value, xla.DeviceValue):
-    val = onp.asarray(fill_value, dtype)
-    return _FilledConstant(val, shape)
-  else:
-    return broadcast(convert_element_type(fill_value, dtype), shape)
+  dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
+  # TODO(mattjj): remove device_put when dtype conversion produces DeviceArray
+  fill_value = xla.device_put_p.bind(convert_element_type(fill_value, dtype))
+  return broadcast(fill_value, shape)
 
 def iota(dtype, size):
   """Wraps XLA's `Iota
   <https://www.tensorflow.org/xla/operation_semantics#iota>`_
   operator.
   """
-  return broadcasted_iota(dtype, (int(size),), 0)
+  size = int(size)
+  dtype = dtypes.canonicalize_dtype(dtype)
+  lazy_expr = lazy.iota(dtype, size)
+  aval = ShapedArray((size,), dtype)
+  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
 
 def broadcasted_iota(dtype, shape, dimension):
-  """Wraps XLA's `Iota
-  <https://www.tensorflow.org/xla/operation_semantics#iota>`_
-  operator.
-  """
+  """Convenience wrapper around ``iota``."""
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = _canonicalize_shape(shape)
   dimension = int(dimension)
-  return _IotaConstant(dtype, shape, dimension)
+  return broadcast_in_dim(iota(dtype, shape[dimension]), shape, [dimension])
 
-def eye(dtype, size):
-  return broadcasted_eye(dtype, (size, size), (0, 1))
+def _eye(dtype, shape, offset):
+  """Like numpy.eye, create a 2D array with ones on a diagonal.
-def broadcasted_eye(dtype, shape, axes): - if not isinstance(axes, (list, tuple)) or not len(axes) >= 2: - raise TypeError("make_diagonal `axes` must be a tuple with len at least 2.") + This function exists for creating lazy identity matrices; that is, + materialization of the array is delayed and it may be fused into consumers to + avoid materialization at all.""" + N, M = tuple(map(int, shape)) + offset = int(offset) dtype = dtypes.canonicalize_dtype(dtype) - shape = _canonicalize_shape(shape) + lazy_expr = lazy.eye(dtype, (N, M), offset) + aval = ShapedArray((N, M), dtype) + return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant()) + +def _delta(dtype, shape, axes): + """This function exists for creating lazy Kronecker delta arrays, particularly + for use in jax.numpy.einsum to express traces. It differs from ``eye`` in that + it can create arrays of any rank, but doesn't allow offsets.""" + shape = tuple(map(int, shape)) axes = tuple(map(int, axes)) - return _EyeConstant(shape, axes, dtype) - + dtype = dtypes.canonicalize_dtype(dtype) + base_shape = tuple(onp.take(shape, axes)) + lazy_expr = lazy.broadcast(lazy.delta(dtype, base_shape), shape, axes) + aval = ShapedArray(shape, dtype) + return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant()) + +def _tri(dtype, shape, offset): + """Like numpy.tri, create a 2D array with ones below a diagonal. + This function exists for creating lazy triangular matrices, particularly for + use in jax.numpy.tri.""" + N, M = tuple(map(int, shape)) + offset = int(offset) + dtype = dtypes.canonicalize_dtype(dtype) + lazy_expr = lazy.tri(dtype, (N, M), offset) + aval = ShapedArray((N, M), dtype) + return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant()) def stop_gradient(x): """Stops gradient computation. 
@@ -2288,11 +2303,23 @@ def _broadcast_batch_rule(batched_args, batch_dims, sizes): ad.deflinear(broadcast_p, lambda t, sizes: [_reduce_sum(t, range(len(sizes)))]) batching.primitive_batchers[broadcast_p] = _broadcast_batch_rule +def _broadcast_in_dim_impl(operand, shape, broadcast_dimensions): + if type(operand) is xla.DeviceArray: + aval = ShapedArray(shape, _dtype(operand)) + lazy_expr = lazy.broadcast(operand._lazy_expr, shape, broadcast_dimensions) + return xla.DeviceArray(aval, None, lazy_expr, operand.device_buffer) + else: + return xla.apply_primitive(broadcast_in_dim_p, operand, shape=shape, + broadcast_dimensions=broadcast_dimensions) def _broadcast_in_dim_shape_rule(operand, shape, broadcast_dimensions): _check_shapelike('broadcast_in_dim', 'shape', shape) _check_shapelike('broadcast_in_dim', 'broadcast_dimensions', broadcast_dimensions) + if any(x >= len(shape) for x in broadcast_dimensions): + msg = ("broadcast_in_dim broadcast dimensions must be less than " + "ndim(shape), got {} for shape {}.") + raise ValueError(msg.format(broadcast_dimensions, shape)) if operand.ndim != len(broadcast_dimensions): msg = ('broadcast_in_dim broadcast_dimensions must have length equal to ' 'operand ndim, got broadcast_dimensions {} for operand ndim {}.') @@ -2319,6 +2346,7 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape, broadcast_in_dim_p = standard_primitive( _broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim') +broadcast_in_dim_p.def_impl(_broadcast_in_dim_impl) ad.deflinear(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule) batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule @@ -2466,9 +2494,14 @@ def _pad_batch_rule(batched_args, batch_dims, padding_config): batching.primitive_batchers[pad_p] = _pad_batch_rule -# We have a nonstandard reshape impl so that we can be lazy about data movement -# for specific types, particularly ShardedDeviceArrays / ChunkedDeviceArrays +# We have a nonstandard reshape impl so that we can be lazy about data movement. def _reshape_impl(operand, new_sizes, dimensions, old_sizes): + if type(operand) is xla.DeviceArray and dimensions is None: + bcast_dims = _is_singleton_reshape(old_sizes, new_sizes) + if bcast_dims is not None: + aval = ShapedArray(new_sizes, operand.dtype) + lazy_expr = lazy.broadcast(operand._lazy_expr, new_sizes, bcast_dims) + return xla.DeviceArray(aval, None, lazy_expr, operand.device_buffer) if (type(operand) is pxla.ShardedDeviceArray and dimensions is None and _is_axis_merge(old_sizes, new_sizes)): aval = ShapedArray(new_sizes, operand.dtype) @@ -2482,6 +2515,26 @@ def _reshape_impl(operand, new_sizes, dimensions, old_sizes): return xla.apply_primitive(reshape_p, operand, new_sizes=new_sizes, dimensions=dimensions, old_sizes=old_sizes) +def _is_singleton_reshape(old, new): + # A singleton reshape is one where only singleton dimensions are added. We + # want to detect them because they can be expressed as (lazy) broadcasts. 
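+  # For example, reshaping (3, 4) to (3, 1, 4, 1) is a singleton reshape with
+  # bcast_dims (0, 2): it is equivalent to broadcast_in_dim with
+  # broadcast_dimensions=(0, 2).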
+ old, new = iter(old), iter(new) + d1, d2 = next(old, None), next(new, None) + bcast_dims = [] + i = 0 + while True: + if d1 is d2 is None: + return bcast_dims + elif d1 == d2: + bcast_dims.append(i) + i += 1 + d1, d2 = next(old, None), next(new, None) + elif d2 == 1: + i += 1 + d2 = next(new, None) + else: + return None + def _is_axis_merge(s1, s2): return s1[2:] == s2[1:] and s1[0] * s1[1] == s2[0] @@ -2560,6 +2613,14 @@ def _rev_batch_rule(batched_args, batch_dims, dimensions): batching.primitive_batchers[rev_p] = _rev_batch_rule +def _transpose_impl(operand, permutation): + if type(operand) is xla.DeviceArray: + lazy_expr = lazy.transpose(operand._lazy_expr, permutation) + aval = ShapedArray(lazy_expr.shape, operand.dtype) + return xla.DeviceArray(aval, None, lazy_expr, operand.device_buffer) + else: + return xla.apply_primitive(transpose_p, operand, permutation=permutation) + def _transpose_shape_rule(operand, permutation): if not isinstance(permutation, (tuple, list, onp.ndarray)): msg = "transpose permutation must be a tuple/list/ndarray, got {}." @@ -2578,6 +2639,7 @@ def _transpose_batch_rule(batched_args, batch_dims, permutation): transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype, 'transpose') +transpose_p.def_impl(_transpose_impl) ad.deflinear(transpose_p, lambda t, permutation: [transpose(t, onp.argsort(permutation))]) batching.primitive_batchers[transpose_p] = _transpose_batch_rule @@ -4037,100 +4099,6 @@ def _tie_in_batch_rule(batched_args, batch_dims): masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1] -### constants - - -class _FilledConstant(xla.DeviceConstant): - __slots__ = ["fill_value"] - - def __init__(self, fill_value, shape): - assert type(fill_value) is onp.ndarray - self.aval = ShapedArray(shape, _dtype(fill_value)) - self._npy_value = None - - self.fill_value = fill_value - - @property - def _value(self): - return onp.full(self.shape, self.fill_value) - - @staticmethod - def constant_handler(c, filled_const, canonicalize_types=True): - return c.Broadcast( - c.NumpyArrayConstant(filled_const.fill_value, canonicalize_types), - filled_const.shape) - - -class _IotaConstant(xla.DeviceConstant): - __slots__ = ["axis"] - - def __init__(self, dtype, shape, axis): - self.aval = ShapedArray(shape, onp.dtype(dtype)) - self._npy_value = None - - self.axis = axis - - @property - def _value(self): - if self._npy_value is None: - iota = onp.arange(self.shape[self.axis], dtype=self.dtype) - iota = iota.reshape([self.shape[self.axis] if i == self.axis else 1 - for i in range(self.ndim)]) - self._npy_value = onp.broadcast_to(iota, self.shape) - return self._npy_value - - @staticmethod - def constant_handler(c, iota_constant, canonicalize_types=True): - dtype = iota_constant.dtype - if canonicalize_types: - dtype = dtypes.canonicalize_dtype(dtype) - return c.BroadcastedIota(dtype, iota_constant.shape, iota_constant.axis) - - -class _EyeConstant(xla.DeviceConstant): - __slots__ = ["axes"] - - def __init__(self, shape, axes, dtype): - self.aval = ShapedArray(shape, onp.dtype(dtype)) - self._npy_value = None - - self.axes = axes - - @property - def _value(self): - if self._npy_value is None: - ones = [1] * self.ndim - iotas = [onp.arange(self.shape[axis]).reshape(subvals(ones, [(axis, -1)])) - for axis in self.axes] - eyes = [i1 == i2 for i1, i2 in zip(iotas[:-1], iotas[1:])] - result = onp.asarray(_reduce(operator.and_, eyes), self.dtype) - self._npy_value = onp.broadcast_to(result, self.shape) - return self._npy_value - - @staticmethod - def 
constant_handler(c, diag_const, canonicalize_types=True): - if canonicalize_types: - etype = xla_bridge.dtype_to_etype(diag_const.dtype) - else: - etype = xla_client.dtype_to_etype(diag_const.dtype) - etype = xla_bridge.dtype_to_etype(diag_const.dtype) - iotas = [c.BroadcastedIota(onp.uint32, diag_const.shape, axis) - for axis in diag_const.axes] - eyes = [c.Eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])] - return c.ConvertElementType(_reduce(c.And, eyes), etype) - - -for _t in [_FilledConstant, _IotaConstant, _EyeConstant]: - xla_bridge.register_constant_handler(_t, _t.constant_handler) - core.pytype_aval_mappings[_t] = ConcreteArray - xla.pytype_aval_mappings[_t] = make_shaped_array - xla.device_put_handlers[_t] = xla._instantiate_device_constant - pxla.shard_arg_handlers[_t] = pxla._shard_array - xla.canonicalize_dtype_handlers[_t] = _identity - ad_util.jaxval_adders[_t] = add - ad_util.jaxval_zeros_likers[_t] = zeros_like_array - - ### stop-gradient def _stop_gradient_jvp_rule(primals, tangents): @@ -4589,13 +4557,6 @@ def _eq_meet(a, b): return eq(a, b) -def subvals(lst, replace): - lst = list(lst) - for i, v in replace: - lst[i] = v - return tuple(lst) - - def _abstractify(x): return raise_to_shaped(core.get_aval(x)) diff --git a/jax/lazy.py b/jax/lazy.py new file mode 100644 index 000000000000..6c4fcbf5dce4 --- /dev/null +++ b/jax/lazy.py @@ -0,0 +1,244 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +import operator as op + +import numpy as onp +from six.moves import reduce + +from .util import safe_map, safe_zip, unzip2, subvals +from .lib import xla_bridge as xb + +map = safe_map +zip = safe_zip + + +### util + +# TODO(mattjj): replace with dataclass when Python 2 support is removed +def taggedtuple(name, fields): + """Lightweight version of namedtuple where equality depends on the type.""" + def __new__(cls, *xs): + return tuple.__new__(cls, (cls,) + xs) + def __str__(self): + return '{}{}'.format(name, tuple.__str__(self[1:])) + class_namespace = {'__new__' : __new__, '__str__': __str__} + for i, f in enumerate(fields): + class_namespace[f] = property(op.itemgetter(i+1)) + return type(name, (tuple,), class_namespace) + + +### lazy sublanguage + +# There are two components to a LazyExpr: an input and a reindexing +# specification. The input represents a base array to which the reindexing +# specification is applied. +# +# An input can represent an array constructor (Iota, Eye, etc.) or it can be an +# ArrayVar which encodes that the base array is some exogenous array value (from +# an environment with only a single value in it). 
These LazyExprs are attached +# to DeviceArrays, so when the input part of the expression is ArrayVar that +# basically means the associated device buffer represents the input, while if +# the input is an array constructor then the associated device_buffer field of +# the DeviceArray should be set to a DeviceConstant sentinel value. For the +# array constructor expressions: +# * Iota builds a 1D sequence [0, 1, ..., N-1], +# * Eye builds a 2D array with ones on a (possibly offset) diagonal and zeros +# elsewhere (like numpy.eye), +# * Tri builds a triangular matrix with ones on and below a diagonal and zeros +# elsewhere (like numpy.tri), and +# * Delta builds a Kronecker delta array with ones along its multidimensional +# main diagonal and zeros elsewhere (for use in tensor contractions). +# +# The reindexing specification encodes the shape of the final result and a list +# of dimensions, which are integers or Nones. The integer entries take on values +# 0, 1, ..., R-1 where R is the rank of the input array, and encode where the +# axes of the input array are to be mapped in the final output. When an entry is +# None that indicates that the corresponding axis of the result is a broadcasted +# one. +# +# Here are some examples of lazy expressions and the arrays they represent: +# +# LazyExpr(input=Iota(dtype=dtype('float32'), size=3), +# shape=(3, 4), dims=(0, None)) +# DeviceArray([[0., 0., 0., 0.], +# [1., 1., 1., 1.], +# [2., 2., 2., 2.]], dtype=float32) +# +# LazyExpr(input=Iota(dtype=dtype('float32'), size=3), +# shape=(4, 3), dims=(None, 0)) +# DeviceArray([[0., 1., 2.], +# [0., 1., 2.], +# [0., 1., 2.], +# [0., 1., 2.]], dtype=float32) +# +# For performance, some functions on lazy expressions accept None as an input to +# stand for the identity lazy expression. +# +# We use the `taggedtuple` class constructor, rather than standard namedtuples, +# because two namedtuple instances of different types but equal elements hash to +# the same value, e.g. +# A = namedtuple('A', ['x', 'y']) +# B = namedtuple('B', ['x', 'y']) +# hash(A(1, 2)) == hash(B(1, 2)) # True +# but we want hashes to be sensitive to the type tag (while still being fast). 
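+#
+# The constructors defined below build these expressions directly. For
+# example, array(shape) builds the identity expression
+#
+#   LazyExpr(input=ArrayVar(), shape=shape, dims=(0, 1, ..., len(shape) - 1))
+#
+# which is_trivial recognizes, so forcing a DeviceArray that carries it is a
+# no-op.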
+ +LazyExpr = namedtuple('LazyExpr', ['input', 'shape', 'dims']) +ArrayVar = taggedtuple('ArrayVar', []) +Iota = taggedtuple('Iota', ['dtype', 'size']) # like np.arange(N) +Eye = taggedtuple('Eye', ['dtype', 'shape', 'offset']) # like np.eye +Tri = taggedtuple('Tri', ['dtype', 'shape', 'offset']) # like np.tri +Delta = taggedtuple('Delta', ['dtype', 'shape']) # kronecker delta arrays + +def array(shape): + return LazyExpr(ArrayVar(), shape, tuple(range(len(shape)))) + +def iota(dtype, size): + return LazyExpr(Iota(dtype, size), (size,), (0,)) + +def eye(dtype, shape, offset): + assert len(shape) == 2 + return LazyExpr(Eye(dtype, shape, offset), shape, (0, 1)) + +def tri(dtype, shape, offset): + assert len(shape) == 2 + return LazyExpr(Tri(dtype, shape, offset), shape, (0, 1)) + +def delta(dtype, shape): + return LazyExpr(Delta(dtype, shape), shape, tuple(range(len(shape)))) + +def broadcast(lexpr, shape, broadcast_dimensions): + new_dims = [None] * len(shape) + for i, d in enumerate(broadcast_dimensions): + new_dims[d] = lexpr.dims[i] + return LazyExpr(lexpr.input, shape, tuple(new_dims)) + +def transpose(lexpr, perm): + new_shape = tuple(lexpr.shape[i] for i in perm) + new_dims = tuple(lexpr.dims[i] for i in perm) + return LazyExpr(lexpr.input, new_shape, new_dims) + +def is_constant(lexpr): + return lexpr is not None and type(lexpr.input) is not ArrayVar + +def is_trivial(lexpr): + return (type(lexpr.input) is ArrayVar and + lexpr.dims == tuple(range(len(lexpr.shape)))) + + +def eval_lexpr(lexpr, x): + """Evaluate a lazy expression using NumPy. + Args: + lexpr: the LazyExpr to evaluate. + x: ndarray or None, representing the value of ArrayVar if present. + Returns: + An ndarray representing the value of the lazy expression. + """ + if is_trivial(lexpr): + return x + + input_, shape, dims = lexpr + + # first create a starting ndarray from input_ + t = type(input_) + if t is ArrayVar: + assert x is not None and type(x) is onp.ndarray + elif t is Iota: + assert x is None + x = onp.arange(input_.size, dtype=input_.dtype) + elif t is Eye: + assert x is None + N, M = input_.shape + x = onp.eye(N, M, dtype=input_.dtype, k=input_.offset) + elif t is Tri: + assert x is None + N, M = input_.shape + x = onp.tri(N, M, dtype=input_.dtype, k=input_.offset) + elif t is Delta: + ones = [1] * len(input_.shape) + iotas = [onp.arange(d).reshape(subvals(ones, [(i, -1)])) + for i, d in enumerate(input_.shape)] + eyes = [i1 == i2 for i1, i2 in zip(iotas[:-1], iotas[1:])] + x = onp.asarray(reduce(op.and_, eyes), input_.dtype) + else: + assert False + + # then apply the reindexing operation + perm = [d for d in dims if d is not None] + if perm != list(range(len(perm))): + x = onp.transpose(x, perm) + if shape != x.shape: + in_shape = [1 if d is None else s for d, s in zip(dims, shape)] + x = onp.broadcast_to(onp.reshape(x, in_shape), shape) + + return x + + +def stage_lexpr(c, lexpr, x): + """Stage a lazy expression into an XLA computation. + Args: + c: XLA ComputationBuilder into which to stage the expression. + lexpr: a LazyExpr to evaluate (or None for the identity expression). + x: XlaOp or None, representing the value of ArrayVar if present. + Returns: + An XlaOp representing the value of the lazy expression. 
+ """ + if lexpr is None or is_trivial(lexpr): + return x + + input_, shape, dims = lexpr + + # first create a starting XlaOp from input_ + t = type(input_) + if t is ArrayVar: + assert x is not None + elif t is Iota: + assert x is None + x = c.Iota(input_.dtype, input_.size) + elif t is Eye: + assert x is None + N, M = input_.shape + bool_eye = c.Eq(c.Add(c.BroadcastedIota(onp.int32, (N, M), 0), + c.Constant(onp.array(input_.offset, onp.int32))), + c.BroadcastedIota(onp.int32, (N, M), 1)) + x = c.ConvertElementType(bool_eye, xb.dtype_to_etype(input_.dtype)) + elif t is Tri: + assert x is None + N, M = input_.shape + bool_tri = c.Ge(c.Add(c.BroadcastedIota(onp.int32, (N, M), 0), + c.Constant(onp.array(input_.offset, onp.int32))), + c.BroadcastedIota(onp.int32, (N, M), 1)) + x = c.ConvertElementType(bool_tri, xb.dtype_to_etype(input_.dtype)) + elif t is Delta: + etype = xb.dtype_to_etype(input_.dtype) + iotas = [c.BroadcastedIota(onp.uint32, input_.shape, i) + for i in range(len(input_.shape))] + eyes = [c.Eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])] + x = c.ConvertElementType(reduce(c.And, eyes), etype) + else: + assert False + + # then apply the operations encoded in reindex + bcast_dims, perm = unzip2((i, d) for i, d in enumerate(dims) if d is not None) + if tuple(perm) != tuple(range(len(perm))): + x = c.Transpose(x, perm) + if shape != c.GetShape(x).dimensions(): + x = c.BroadcastInDim(x, shape, bcast_dims) + + return x diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index d27dd9e4bfb5..793312bfdf88 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -52,7 +52,7 @@ from ..config import flags from ..interpreters.xla import DeviceArray from .. import lax -from ..util import partial, get_module_functions, unzip2, prod as _prod +from ..util import partial, get_module_functions, unzip2, prod as _prod, subvals from ..lib import pytree from ..lib import xla_client @@ -1087,7 +1087,7 @@ def split(ary, indices_or_sections, axis=0): subarrays = onp.split(dummy_val, indices_or_sections, axis) # shapes split_indices = onp.cumsum([0] + [onp.shape(sub)[axis] for sub in subarrays]) starts, ends = [0] * ndim(ary), shape(ary) - _subval = lambda x, i, v: lax.subvals(x, [(i, v)]) + _subval = lambda x, i, v: subvals(x, [(i, v)]) return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end)) for start, end in zip(split_indices[:-1], split_indices[1:])] @@ -1250,7 +1250,7 @@ def reduction(a, axis=None, dtype=None, out=None, keepdims=False): result = lax.reduce(a, _reduction_init_val(a, init_val), op if computation_dtype != onp.bool_ else bool_op, dims) if keepdims: - shape_with_singletons = lax.subvals(shape(a), zip(dims, (1,) * len(dims))) + shape_with_singletons = subvals(shape(a), zip(dims, (1,) * len(dims))) result = lax.reshape(result, shape_with_singletons) return lax.convert_element_type(result, dtype or result_dtype) @@ -1813,23 +1813,20 @@ def array_equal(a1, a2): @_wraps(onp.eye) -def eye(N, M=None, k=None, dtype=None): +def eye(N, M=None, k=0, dtype=None): lax._check_user_dtype_supported(dtype, "eye") dtype = float_ if dtype is None else dtype M = N if M is None else M + k = int(k) if N < 0 or M < 0: msg = "negative dimensions are not allowed, got {} and {}" raise ValueError(msg.format(N, M)) - if k is None: - return lax.broadcasted_eye(dtype, (N, M), (0, 1)) - else: + if k is not None: k_dtype = _dtype(k) if not issubdtype(k_dtype, integer): msg = "eye argument `k` must be of integer dtype, got {}" raise TypeError(msg.format(k_dtype)) - rows = 
k + lax.broadcasted_iota(k_dtype, (N, M), 0) - cols = lax.broadcasted_iota(k_dtype, (N, M), 1) - return lax.convert_element_type(lax.eq(rows, cols), dtype) + return lax._eye(dtype, (N, M), k) @_wraps(onp.identity) @@ -1841,16 +1838,11 @@ def identity(n, dtype=None): @_wraps(onp.arange) def arange(start, stop=None, step=None, dtype=None): lax._check_user_dtype_supported(dtype, "arange") - # If called like np.arange(N), we create a lazy lax._IotaConstant. if stop is None and step is None: dtype = dtype or _dtype(start) - if issubdtype(dtype, integer): - return lax.iota(dtype, start) # avoids materializing - - # Fall back to instantiating an ndarray in host memory - dtype = dtype or result_type( - *(x for x in (start, stop, step) if x is not None)) - return onp.arange(start, stop=stop, step=step, dtype=dtype) + return lax.iota(dtype, start) # avoids materializing + else: + return array(onp.arange(start, stop=stop, step=step, dtype=dtype)) def _wrap_numpy_nullary_function(f): @@ -2074,13 +2066,7 @@ def tri(N, M=None, k=0, dtype=None): lax._check_user_dtype_supported(dtype, "tri") M = M if M is not None else N dtype = dtype or float32 - x = arange(N, dtype=int32) - y = arange(M, dtype=int32) - mask = lax.ge( - (lax.broadcast_in_dim(x, shape=(N, M), broadcast_dimensions=(0,)) + - int32(k)), - lax.broadcast(y, [N])) - return lax.convert_element_type(mask, dtype) + return lax._tri(dtype, (N, M), k) @_wraps(onp.tril) @@ -2339,7 +2325,7 @@ def sum_repeats(operand, names, counts, keep_names): for name, count in counts.items(): if count > 1: axes = [i for i, n in enumerate(names) if n == name] - eye = lax.broadcasted_eye(operand.dtype, operand.shape, axes) + eye = lax._delta(operand.dtype, operand.shape, axes) if name not in keep_names: operand = sum(operand * eye, axes) names = names.replace(name, '') @@ -2849,7 +2835,7 @@ def _index_to_gather(x_shape, idx): collapsed_slice_dims = [] start_index_map = [] - gather_indices = zeros((0,), dtype=int32) + gather_indices = onp.zeros((0,), dtype=int32) # use onp to save a compilation # We perform three transformations to y before the scatter op, in order: # First, y is broadcast to slice_shape. In general `y` only need broadcast to diff --git a/jax/scipy/special.py b/jax/scipy/special.py index 5129d0ff112b..d8f8a2f9a67b 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -20,6 +20,7 @@ import scipy.special as osp_special from .. import lax +from .. 
import util from ..api import custom_transforms, defjvp from ..numpy import lax_numpy as np from ..numpy.lax_numpy import (_wraps, asarray, _reduction_dims, _constant_like, @@ -84,7 +85,7 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False): if b is not None or return_sign: raise NotImplementedError("Only implemented for b=None, return_sign=False") dims = _reduction_dims(a, axis) - shape = lax.subvals(onp.shape(a), zip(dims, (1,) * len(dims))) + shape = util.subvals(onp.shape(a), zip(dims, (1,) * len(dims))) dimadd = lambda x: lax.reshape(x, shape) amax = lax.reduce(a, _constant_like(a, -onp.inf), lax.max, dims) amax = lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)) diff --git a/jax/util.py b/jax/util.py index b78bd20e2f6f..0ad1e63d0ca6 100644 --- a/jax/util.py +++ b/jax/util.py @@ -56,6 +56,12 @@ def unzip3(xyzs): zs.append(z) return tuple(xs), tuple(ys), tuple(zs) +def subvals(lst, replace): + lst = list(lst) + for i, v in replace: + lst[i] = v + return tuple(lst) + def split_list(args, ns): assert type(ns) is list args = list(args) diff --git a/tests/api_test.py b/tests/api_test.py index 286ef024398d..6b83a01a67ec 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -17,6 +17,7 @@ from __future__ import print_function import collections +from contextlib import contextmanager import copy from functools import partial import unittest @@ -24,7 +25,7 @@ import weakref from absl import logging -from absl.testing import absltest +from absl.testing import absltest, parameterized import numpy as onp import six @@ -34,7 +35,7 @@ import jax import jax.numpy as np from jax import jit, grad, device_put, jacfwd, jacrev, hessian -from jax import api, core, lax +from jax import api, core, lax, lax_reference from jax.core import Primitive from jax.interpreters import ad from jax.interpreters import xla @@ -1344,7 +1345,7 @@ def f(x): python_should_be_executing = False api.pmap(f, 'i')(x) - def test_repr(self): + def test_device_array_repr(self): rep = repr(np.ones(()) + 1.) 
self.assertStartsWith(rep, 'DeviceArray') @@ -1694,5 +1695,186 @@ def f(x): """) +class LazyTest(jtu.JaxTestCase): + + @contextmanager + def count_compiles(self): + + make_computation_builder = xb.make_computation_builder + count = [0] + + def make_computation_builder_and_count(*args, **kwargs): + count[0] += 1 + return make_computation_builder(*args, **kwargs) + + xb.make_computation_builder = make_computation_builder_and_count + try: + yield count + finally: + xb.make_computation_builder = make_computation_builder + + @jtu.skip_on_devices("tpu") + def test_lazy_jit_closed_over_values(self): + if not core.skip_checks: + raise SkipTest("oom test skipped when core.skip_checks is False") + + y = np.arange(int(1e12)) # will likely oom if materialized + ans = jit(lambda x: (x + y)[1])(1) + self.assertEqual(ans, 2) + + def test_jit_forces_arguments(self): + + @api.jit + def f(x): + assert python_should_be_executing + return np.sum(x) + + x = np.arange(10, dtype=np.int32) + assert xla.is_device_constant(x) # lazy iota + + python_should_be_executing = True + _ = f(x) + + python_should_be_executing = False # should not recompile + x = onp.arange(10, dtype=onp.int32) + _ = f(x) + + @parameterized.parameters(jtu.cases_from_list(range(10000))) + def test_random_lazy_program(self, seed): + + def random_array(rng): + kind = rng.choice(['arr', 'iota', 'eye', 'tri']) + if kind == 'arr': + dtype = [onp.float32, onp.int32][rng.choice(2)] + dim = rng.randint(4) + shape = rng.randint(4, size=dim) + onp_x = onp.asarray(rng.randn(*shape), dtype=dtype) + jax_x = np.array(onp_x, dtype=dtype) + elif kind == 'iota': + dtype = [onp.float32, onp.int32][rng.choice(2)] + size = rng.randint(5) + onp_x = onp.arange(size, dtype=dtype) + jax_x = lax.iota(dtype, size) + elif kind == 'eye': + dtype = [onp.float32, onp.int32][rng.choice(2)] + N = rng.randint(2, 5) + M = None if rng.rand() < 0.5 else rng.randint(2, 5) + k = rng.choice([-1, 0, 1]) + onp_x = onp.eye(N, M, k, dtype=dtype) + jax_x = np.eye(N, M, k, dtype=dtype) + elif kind == 'tri': + dtype = [onp.float32, onp.int32][rng.choice(2)] + N = rng.randint(2, 5) + M = None if rng.rand() < 0.5 else rng.randint(2, 5) + k = rng.choice([-1, 0, 1]) + onp_x = onp.tri(N, M, k, dtype=dtype) + jax_x = np.tri(N, M, k, dtype=dtype) + else: + assert False + assert type(onp_x) is onp.ndarray and type(jax_x) is xla.DeviceArray + return onp_x, jax_x + + def random_op(rng, shape): + kind = rng.choice(['transpose', 'broadcast', 'reshape']) + if kind == 'transpose': + perm = tuple(rng.permutation(len(shape))) + return Op(partial(onp.transpose, axes=perm), + partial(lax.transpose, permutation=perm)) + elif kind == 'broadcast': + n = rng.randint(1, 3) + new_sizes = rng.randint(1, 4, size=n) + new_ndim = n + len(shape) + bcast_dims = tuple(sorted(rng.permutation(new_ndim)[:len(shape)])) + shape_iter = iter(shape) + new_sizes = iter(rng.randint(1, 4, size=n)) + new_shape = [next(shape_iter) if i in bcast_dims else next(new_sizes) + for i in range(new_ndim)] + return Op(partial(lax_reference.broadcast_in_dim, shape=new_shape, + broadcast_dimensions=bcast_dims), + partial(lax.broadcast_in_dim, shape=new_shape, + broadcast_dimensions=bcast_dims)) + elif kind == 'reshape': + new_shape = list(shape) + for _ in range(rng.randint(1, 3)): + loc = len(new_shape) and rng.randint(len(new_shape)) + new_shape.insert(loc, 1) + new_shape = tuple(new_shape) + return Op(partial(onp.reshape, newshape=new_shape), + partial(lax.reshape, new_sizes=new_shape)) + else: + assert False + Op = 
collections.namedtuple('Op', ['onp_fn', 'jax_fn']) + + rng = onp.random.RandomState(seed) + onp_x, jax_x = _, orig_x = random_array(rng) + ops = [] + with jtu.count_primitive_compiles() as count: + for _ in range(rng.randint(5)): + op = random_op(rng, onp.shape(onp_x)) + onp_x = op.onp_fn(onp_x) + jax_x = op.jax_fn(jax_x) + ops.append(op) + self.assertEqual(count[0], 0) + + kind = rng.choice(['closure', 'npy_value', 'force', 'add']) + if kind == 'closure': + result = api.jit(lambda x: x + jax_x)(0) + self.assertAllClose(onp_x, result, check_dtypes=False) + elif kind == 'npy_value': + self.assertAllClose(onp_x, jax_x, check_dtypes=False) + elif kind == 'force': + result = xla._force(jax_x) + self.assertAllClose(onp_x, result, check_dtypes=False) + elif kind == 'add': + result = jax_x + onp.zeros(jax_x.shape, dtype=jax_x.dtype) + self.assertAllClose(onp_x, result, check_dtypes=False) + else: + assert False + + @jit + def apply_ops(x): + for op in ops: + x = op.jax_fn(x) + return x + + jit_result = apply_ops(orig_x) + self.assertAllClose(jit_result, onp_x, check_dtypes=False) + + @jit + def apply_ops_closure(): + x = orig_x + for op in ops: + x = op.jax_fn(x) + return x + + jit_result = apply_ops_closure() + self.assertAllClose(jit_result, onp_x, check_dtypes=False) + + def test_constant_forcing_computations_cached(self): + # from https://github.com/google/jax/issues/1909 + xla._lazy_force_computation.cache_clear() # clear force compile cache + big_lazy_x = np.ones((api.device_count(), 100)) + f = api.pmap(lambda x: 2 * x) + _ = f(big_lazy_x) + + with self.count_compiles() as count: + _ = f(big_lazy_x) + self.assertEqual(count[0], 0) + + def test_zeros_ones_compilation(self): + w = np.ones(3) + np.ones(3) # ensure + has a cache entry + w.block_until_ready() + + xla._lazy_force_computation.cache_clear() # clear force compile cache + + with self.count_compiles() as count: + x = np.ones(3) + np.zeros(3) + y = np.ones(3) + np.ones(3) + + self.assertEqual(count[0], 1) + self.assertAllClose(x, onp.ones(3), check_dtypes=False) + self.assertAllClose(y, onp.ones(3) + onp.ones(3), check_dtypes=False) + + if __name__ == '__main__': absltest.main() diff --git a/tests/batching_test.py b/tests/batching_test.py index c442b6490a88..9a8fd0d2e71e 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -299,7 +299,7 @@ def testDot4(self): ans = vmap(np.dot, in_axes=(1, None))(xs, ys) expected = onp.einsum('ij,i->j', xs, ys) self.assertAllClose(ans, expected, check_dtypes=False) - + def testDot5(self): f = vmap(partial(np.einsum, 'ij,j->i'), (None, 0)) jaxpr = make_jaxpr(f)(np.zeros((1000, 1000)), np.zeros((1000, 1000))) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index bf4580852cde..bf5684a19c26 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -32,6 +32,7 @@ from jax import numpy as lnp from jax import ops from jax import test_util as jtu +from jax import util from jax.config import config config.parse_flags_with_absl() @@ -426,8 +427,8 @@ def _ReplaceSlicesWithTuples(self, idx): isnone = [i for i, elt in enumerate(triple) if elt is None] zeros = itertools.repeat(0) nones = itertools.repeat(None) - out = lax.subvals(triple, zip(isnone, zeros)) - return out, lambda out: slice(*lax.subvals(out, zip(isnone, nones))) + out = util.subvals(triple, zip(isnone, zeros)) + return out, lambda out: slice(*util.subvals(out, zip(isnone, nones))) elif isinstance(idx, (tuple, list)) and idx: t = type(idx) elts, packs = 
zip(*map(self._ReplaceSlicesWithTuples, idx)) @@ -630,7 +631,7 @@ def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer): args_maker = lambda: [rng(shape, dtype), indexer_with_dummies] def fun(x, indexer_with_dummies): - idx = type(indexer)(lax.subvals(indexer_with_dummies, substitutes)) + idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes)) return x[idx] self._CompileAndCheck(fun, args_maker, check_dtypes=True) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index d3c25c7997d0..b7d585df8feb 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -40,7 +40,7 @@ from jax import test_util as jtu from jax import dtypes from jax import tree_util -from jax.interpreters import partial_eval +from jax.interpreters import partial_eval, xla from jax.test_util import check_grads from jax.config import config @@ -2074,11 +2074,10 @@ def testArange(self): self.assertAllClose(lnp.arange(53, 5, -3), onp.arange(53, 5, -3, dtype=lnp.int_), check_dtypes=True) - # TODO(mattjj): make these tests work when jax_enable_x64=True - # self.assertAllClose(lnp.arange(77, dtype=float), - # onp.arange(77, dtype=float), check_dtypes=True) - # self.assertAllClose(lnp.arange(2, 13, dtype=int), - # onp.arange(2, 13, dtype=int), check_dtypes=True) + self.assertAllClose(lnp.arange(77, dtype=float), + onp.arange(77, dtype=float), check_dtypes=True) + self.assertAllClose(lnp.arange(2, 13, dtype=int), + onp.arange(2, 13, dtype=int), check_dtypes=True) self.assertAllClose(lnp.arange(0, 1, -0.5), onp.arange(0, 1, -0.5, dtype=lnp.float_), check_dtypes=True) @@ -2095,6 +2094,10 @@ def testArange(self): self.assertTrue(type(lnp.arange(77, dtype=lnp.int32)) == type(lax.iota(onp.int32, 77))) + # test laziness for int dtypes + self.assertTrue(xla.is_device_constant(lnp.arange(77))) + self.assertTrue(xla.is_device_constant(lnp.arange(77, dtype=lnp.int32))) + def testIssue830(self): a = lnp.arange(4, dtype=lnp.complex64) self.assertEqual(a.dtype, lnp.complex64) diff --git a/tests/lax_test.py b/tests/lax_test.py index 6ad052648e07..e5a625cc5783 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1543,8 +1543,8 @@ def testDynamicUpdateSliceTypeErrors(self): onp.zeros((2, 2), dtype=onp.float32), (onp.int32(1), onp.int16(2)))) -class DeviceConstantTest(jtu.JaxTestCase): - def _CheckDeviceConstant(self, make_const, expected): +class LazyConstantTest(jtu.JaxTestCase): + def _Check(self, make_const, expected): # check casting to ndarray works asarray_result = onp.asarray(make_const()) @@ -1575,7 +1575,7 @@ def testFilledConstant(self, shape, fill_value, dtype): make_const = lambda: lax.full(shape, fill_value, dtype) expected = onp.full(shape, fill_value, dtype or dtypes.result_type(fill_value)) - self._CheckDeviceConstant(make_const, expected) + self._Check(make_const, expected) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_dim={}".format( @@ -1595,7 +1595,7 @@ def testIotaConstant(self, dtype, shape, dimension): singleton_shape[dimension] = shape[dimension] expected = onp.broadcast_to(arr.reshape(singleton_shape), shape) - self._CheckDeviceConstant(make_const, expected) + self._Check(make_const, expected) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_axes={}".format( @@ -1612,13 +1612,12 @@ def testIotaConstant(self, dtype, shape, dimension): [(2, 3, 4, 2), (0, 2, 3)], [(1001, 1001), (0, 1)], ])) - def testEyeConstant(self, dtype, shape, axes): - make_const = lambda: lax.broadcasted_eye(dtype, shape, axes) - + 
@jtu.skip_on_devices("tpu") # TODO(mattjj): investigate failure + def testDeltaConstant(self, dtype, shape, axes): + make_const = lambda: lax._delta(dtype, shape, axes) # don't check the asarray case, just assume it's right expected = onp.asarray(make_const()) - - self._CheckDeviceConstant(make_const, expected) + self._Check(make_const, expected) GradTestSpec = collections.namedtuple( diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 4c212b7c5eac..303baccfa6eb 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -161,6 +161,40 @@ def test_primitive_compilation_cache(self): self.assertEqual(y.device_buffer.device(), jax.devices()[1]) self.assertEqual(z.device_buffer.device(), jax.devices()[1]) + def test_device_put(self): + if len(jax.devices()) < 2: + raise SkipTest("test requires multiple devices") + + # test device_put on regular values + x = jax.device_put(1, device=jax.devices()[0]) + self.assertEqual(x.device_buffer.device(), jax.devices()[0]) + + # test device_put on its own output + y = jax.device_put(x, device=jax.devices()[1]) + self.assertEqual(y.device_buffer.device(), jax.devices()[1]) + + # test device_put on lazy values + x = jax.device_put(np.zeros(2), device=jax.devices()[0]) + self.assertEqual(x.device_buffer.device(), jax.devices()[0]) + + y = jax.device_put(x, device=jax.devices()[1]) + self.assertEqual(y.device_buffer.device(), jax.devices()[1]) + + x = jax.device_put(np.zeros(2), device=jax.devices()[1]) + self.assertEqual(x.device_buffer.device(), jax.devices()[1]) + + def test_closed_over_values_device_placement(self): + # see https://github.com/google/jax/issues/1431 + if len(jax.devices()) < 2: + raise SkipTest("test requires multiple devices") + + def f(): return lax.add(3., 4.) + self.assertIsInstance(f(), xla.DeviceArray) + self.assertEqual(f().device_buffer.device(), jax.devices()[0]) + self.assertEqual(jax.jit(f)().device_buffer.device(), jax.devices()[0]) + self.assertEqual(jax.jit(f, device=jax.devices()[1])().device_buffer.device(), + jax.devices()[1]) + if __name__ == '__main__': absltest.main() diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py index 7f54e45d75e7..ea1421e241ae 100644 --- a/tests/multibackend_test.py +++ b/tests/multibackend_test.py @@ -139,6 +139,18 @@ def get_arr(scale): self.assertEqual(b.device_buffer.device(), api.devices('cpu')[0]) self.assertEqual(c.device_buffer.device(), api.devices('cpu')[0]) + @jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends + def test_closed_over_values_device_placement(self): + # see https://github.com/google/jax/issues/1431 + if len(jax.devices()) < 2: + raise SkipTest("test requires multiple devices") + + def f(): return lax.add(3., 4.) + self.assertNotEqual(jax.jit(f)().device_buffer.device(), + api.devices('cpu')[0]) + self.assertEqual(jax.jit(f, backend='cpu')().device_buffer.device(), + api.devices('cpu')[0]) + if __name__ == "__main__": absltest.main() diff --git a/tests/pmap_test.py b/tests/pmap_test.py index aa4fe994c38e..74afbfa82c1e 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -483,6 +483,7 @@ def testPmapConstant(self): self.assertAllClose(ans, expected, check_dtypes=False) f = pmap(lambda x: (x, 3)) + x = onp.arange(device_count) with jtu.count_jit_and_pmap_compiles() as count: _, ans = f(x) self.assertEqual(count[0], 1)