Remove device constants from lazy language.
Updated version of jax-ml#4536.

This removes the device constant part of jax-ml#1668. We can do this because, after jax-ml#3370 and jax-ml#4038, omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power-to-weight ratio is no longer worthwhile.)

After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).

This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
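
As a rough illustration of why omnistaging makes this safe (an illustrative sketch, not part of this commit; exact jaxpr output varies by JAX version): constant-building operations such as lax.iota and jnp.eye are staged into the traced computation, so XLA receives the constant-creating ops directly and no lazy device constant is needed on the Python side.

import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  # Under omnistaging these constant-building calls become iota/eq/broadcast
  # ops in the jaxpr rather than lazy device constants held on DeviceArrays.
  eye = jnp.eye(3, dtype=x.dtype)
  return x + lax.iota(x.dtype, 3) + eye.sum(0)

print(jax.make_jaxpr(f)(jnp.ones(3)))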
hawkinsp committed Mar 4, 2021
1 parent 6c102d9 commit b142653
Showing 6 changed files with 52 additions and 216 deletions.
51 changes: 14 additions & 37 deletions jax/_src/lax/lax.py
@@ -1487,17 +1487,9 @@ def iota(dtype: DType, size: int) -> Array:
   <https://www.tensorflow.org/xla/operation_semantics#iota>`_
   operator.
   """
-  if config.omnistaging_enabled:
-    dtype = dtypes.canonicalize_dtype(dtype)
-    size = core.concrete_or_error(int, size, "size argument of lax.iota")
-    return iota_p.bind(dtype=dtype, shape=(size,), dimension=0)
-  else:
-    size = size if type(size) is masking.Poly else int(size)
-    shape = canonicalize_shape((size,))
-    dtype = dtypes.canonicalize_dtype(dtype)
-    lazy_expr = lazy.iota(dtype, shape[0])
-    aval = ShapedArray(shape, dtype)
-    return xla._DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  dtype = dtypes.canonicalize_dtype(dtype)
+  size = core.concrete_or_error(int, size, "size argument of lax.iota")
+  return iota_p.bind(dtype=dtype, shape=(size,), dimension=0)
 
 def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array:
   """Convenience wrapper around ``iota``."""
@@ -1512,45 +1504,30 @@ def _eye(dtype: DType, shape: Shape, offset: int) -> Array:
   N, M = tuple(map(int, shape))
   offset = int(offset)
   dtype = dtypes.canonicalize_dtype(dtype)
-  if config.omnistaging_enabled:
-    bool_eye = eq(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
-                  broadcasted_iota(np.int32, (N, M), 1))
-    return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False)
-  else:
-    lazy_expr = lazy.eye(dtype, (N, M), offset)
-    aval = ShapedArray((N, M), dtype)
-    return xla._DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  bool_eye = eq(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
+                broadcasted_iota(np.int32, (N, M), 1))
+  return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False)
 
 def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
   """This utility function exists for creating Kronecker delta arrays."""
   shape = tuple(map(int, shape))
   axes = tuple(map(int, axes))
   dtype = dtypes.canonicalize_dtype(dtype)
   base_shape = tuple(np.take(shape, axes)) # type: ignore[arg-type]
-  if config.omnistaging_enabled:
-    iotas = [broadcasted_iota(np.uint32, base_shape, i)
-             for i in range(len(base_shape))]
-    eyes = [eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
-    result = convert_element_type_p.bind(_reduce(operator.and_, eyes), new_dtype=dtype, weak_type=False)
-    return broadcast_in_dim(result, shape, axes)
-  else:
-    lazy_expr = lazy.broadcast(lazy.delta(dtype, base_shape), shape, axes)
-    aval = ShapedArray(shape, dtype)
-    return xla._DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  iotas = [broadcasted_iota(np.uint32, base_shape, i)
+           for i in range(len(base_shape))]
+  eyes = [eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
+  result = convert_element_type_p.bind(_reduce(operator.and_, eyes), new_dtype=dtype, weak_type=False)
+  return broadcast_in_dim(result, shape, axes)
 
 def _tri(dtype: DType, shape: Shape, offset: int) -> Array:
   """Like numpy.tri, create a 2D array with ones below a diagonal."""
   N, M = tuple(map(int, shape))
   offset = int(offset)
   dtype = dtypes.canonicalize_dtype(dtype)
-  if config.omnistaging_enabled:
-    bool_tri = ge(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
-                  broadcasted_iota(np.int32, (N, M), 1))
-    return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False)
-  else:
-    lazy_expr = lazy.tri(dtype, (N, M), offset)
-    aval = ShapedArray((N, M), dtype)
-    return xla._DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  bool_tri = ge(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
+                broadcasted_iota(np.int32, (N, M), 1))
+  return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False)
 
 def stop_gradient(x):
   """Stops gradient computation.
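
The iota-comparison constructions above can be sanity-checked against NumPy using public APIs (an illustrative sketch that mirrors _eye, _tri, and _delta; it substitutes astype/jnp.asarray for the internal convert_element_type_p.bind):

import numpy as np
import jax.numpy as jnp
from jax import lax

N, M, offset = 4, 5, 1
row = lax.broadcasted_iota(np.int32, (N, M), 0)
col = lax.broadcasted_iota(np.int32, (N, M), 1)

# _eye: row + offset == col gives ones on the k=offset diagonal.
eye = (row + offset == col).astype(np.float32)
assert np.array_equal(eye, np.eye(N, M, k=offset, dtype=np.float32))

# _tri: row + offset >= col gives ones on and below the k=offset diagonal.
tri = (row + offset >= col).astype(np.float32)
assert np.array_equal(tri, np.tri(N, M, k=offset, dtype=np.float32))

# _delta: pairwise equality of iotas gives a Kronecker delta on the chosen axes,
# which broadcast_in_dim then expands back to the full shape.
shape, axes = (3, 2, 3), (0, 2)
base_shape = tuple(np.take(shape, axes))
iotas = [lax.broadcasted_iota(np.uint32, base_shape, i) for i in range(len(base_shape))]
delta = jnp.asarray(iotas[0] == iotas[1], np.float32)
delta_full = lax.broadcast_in_dim(delta, shape, axes)
expected = np.broadcast_to(np.eye(3, dtype=np.float32)[:, None, :], shape)
assert np.array_equal(np.asarray(delta_full), expected)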
2 changes: 1 addition & 1 deletion jax/api.py
@@ -311,7 +311,7 @@ def cache_miss(_, *args, **kwargs):
         # has been reset to None). Thus, we do not support the fast-path.
         execute is not None and
         execute.func is xla._execute_compiled and # not trivial, not pmap
-        # Not supported: ShardedDeviceArray, DeviceConstant.
+        # Not supported: ShardedDeviceArray
         all(xla.type_is_device_array(x) for x in out_flat) and
         # TODO(mattjj): Add support for lazy-expression.
         # If the input is a DeviceArray, then it should have a trivial LazyExpr.
53 changes: 13 additions & 40 deletions jax/interpreters/xla.py
@@ -1034,15 +1034,14 @@ def make_device_array(
     aval: core.ShapedArray,
     device: Optional[Device],
     lazy_expr: Optional[lazy.LazyExpr],
-    device_buffer: Union[PyLocalBuffer, "DeviceConstant"],
+    device_buffer: PyLocalBuffer,
 ) -> Union[PyLocalBuffer, "_DeviceArray"]:
   """Returns a DeviceArray implementation based on arguments.
 
   This is to be used only within JAX. It will return either a PythonDeviceArray
   or a C++ equivalent implementation.
   """
-  if (_EXPERIMENTAL_CPP_DEVICE_ARRAY and lazy.is_trivial(lazy_expr) and
-      not isinstance(device_buffer, DeviceConstant)):
+  if _EXPERIMENTAL_CPP_DEVICE_ARRAY and lazy.is_trivial(lazy_expr):
     assert isinstance(device_buffer, _CppDeviceArray)
     device_buffer._device = device # pylint: disable=protected-access
     device_buffer.aval = aval
@@ -1120,10 +1119,7 @@ def block_until_ready(self):
   def _value(self):
     self._check_if_deleted()
     if self._npy_value is None:
-      if is_device_constant(self):
-        self._npy_value = lazy.eval_lexpr(self._lazy_expr, None)
-      else:
-        self._npy_value = _force(self).device_buffer.to_py()
+      self._npy_value = _force(self).device_buffer.to_py()
       self._npy_value.flags.writeable = False
     return self._npy_value

@@ -1146,7 +1142,7 @@ def ndim(self):
   def copy_to_host_async(self):
     """Requests a copy of the buffer to the host."""
     self._check_if_deleted()
-    if self._npy_value is None and not is_device_constant(self):
+    if self._npy_value is None:
       self.device_buffer.copy_to_host_async() # pytype: disable=attribute-error
 
   def delete(self):
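
With the device-constant branch gone, materializing a DeviceArray on the host always goes through its device buffer. A small illustrative sketch using public behavior:

import numpy as np
import jax.numpy as jnp

a = jnp.eye(3)               # now backed by a real device buffer, not a DeviceConstant
a.copy_to_host_async()       # optionally start the device-to-host transfer early
host = np.asarray(a)         # _value path: _force(a).device_buffer.to_py()
print(host.flags.writeable)  # typically False: the cached host value is marked read-only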
@@ -1282,14 +1278,6 @@ def raise_not_implemented():
 class DeletedBuffer(object): pass
 deleted_buffer = DeletedBuffer()
 
-class DeviceConstant(object):
-  __slots__ = ["_device"]
-  def __init__(self, device=None): self._device = device
-  def device(self): return self._device
-  def to_py(self): return None
-
-def is_device_constant(x):
-  return type_is_device_array(x) and type(x.device_buffer) is DeviceConstant
 
 for device_array in [_CppDeviceArray, _DeviceArray]:
   core.literalable_types.add(device_array)
@@ -1298,11 +1286,8 @@ def is_device_constant(x):
   canonicalize_dtype_handlers[device_array] = identity
 
 def _device_array_constant_handler(c, val, canonicalize_types=True):
-  if is_device_constant(val):
-    return lazy.stage_lexpr(c, val._lazy_expr, None)
-  else:
-    base_val = xb.constant(c, val.device_buffer.to_py())
-    return lazy.stage_lexpr(c, val._lazy_expr, base_val)
+  base_val = xb.constant(c, val.device_buffer.to_py())
+  return lazy.stage_lexpr(c, val._lazy_expr, base_val)
 xb.register_constant_handler(_DeviceArray, _device_array_constant_handler)
 xb.register_constant_handler(_CppDeviceArray, _device_array_constant_handler)

@@ -1316,10 +1301,6 @@ def _copy_device_array_to_device(x: Union[DeviceArrayProtocol, _DeviceArray], de
   if device is None:
     # no copying to be done because there's no target specified
     return x
-  elif is_device_constant(x):
-    # create a new DeviceArray with the same lazy expr, no copying
-    return make_device_array(x.aval, device, x._lazy_expr,
-                             DeviceConstant(device))
   elif xb.get_device_backend(device).platform == x.device_buffer.platform():
     # source and target platforms are the same
     if x.device_buffer.device() == device:
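
A small sketch of the user-visible paths these handlers serve (illustrative only; the internal names in the comments are the ones shown in this diff):

import jax
import jax.numpy as jnp

const = jnp.eye(3)  # an ordinary DeviceArray; no DeviceConstant backing anymore

@jax.jit
def g(x):
  # A closed-over DeviceArray is typically lowered through
  # _device_array_constant_handler: its buffer is fetched with to_py() and
  # embedded as an XLA constant, with any remaining lazy expression staged on top.
  return x @ const

print(g(jnp.ones((2, 3))))

# Moving an array between devices now always goes through a real buffer copy in
# _copy_device_array_to_device; the "rebuild the constant on the target device"
# shortcut removed above no longer exists.
on_dev0 = jax.device_put(const, jax.devices()[0])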
@@ -1355,14 +1336,11 @@ def _lazy_force_computation(aval: core.ShapedArray,
                             device: Device, lexpr: lazy.LazyExpr
                             ) -> Callable[[_DeviceArray], PyLocalBuffer]:
   c = xb.make_computation_builder("lazy_force")
-  if lazy.is_constant(lexpr):
-    param = None
-  else:
-    idxs = [(src, dst) for dst, src in enumerate(lexpr.dims) if src is not None]
-    param_shape = [None] * len(idxs)
-    for src, dst in idxs:
-      param_shape[src] = aval.shape[dst]
-    param = xb.parameter(c, 0, xc.Shape.array_shape(aval.dtype, param_shape))
+  idxs = [(src, dst) for dst, src in enumerate(lexpr.dims) if src is not None]
+  param_shape = [None] * len(idxs)
+  for src, dst in idxs:
+    param_shape[src] = aval.shape[dst]
+  param = xb.parameter(c, 0, xc.Shape.array_shape(aval.dtype, param_shape))
   xla_out = lazy.stage_lexpr(c, lexpr, param)
   built_c = c.build(xla_out)

@@ -1373,13 +1351,8 @@ def _lazy_force_computation(aval: core.ShapedArray,
                                    device_assignment=device and (device.id,))
   compiled = backend_compile(xb.get_device_backend(device), built_c, options)
 
-  force_fun: Callable[[_DeviceArray], PyLocalBuffer]
-  if lazy.is_constant(lexpr):
-    def force_fun(_):
-      return compiled.execute([])[0]
-  else:
-    def force_fun(x):
-      return compiled.execute([x.device_buffer])[0]
+  def force_fun(x: _DeviceArray) -> PyLocalBuffer:
+    return compiled.execute([x.device_buffer])[0]
   return force_fun


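After this commit the lazy sublanguage only covers broadcasts and transposes, which is what _force/_lazy_force_computation above still serve. A rough op-by-op sketch (illustrative; whether a given result actually carries a non-trivial lazy expression is version-dependent):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0)             # concrete DeviceArray
y = lax.broadcast(x, (4,))      # broadcast of a concrete array: shape (4, 6)
z = jnp.reshape(x, (2, 3)).T    # transpose of a concrete array: shape (3, 2)

# Op-by-op, such results may still be represented as a lazy broadcast/transpose
# expression over the original buffer. Forcing them (printing, np.asarray, etc.)
# now always compiles a computation that takes the source buffer as a parameter;
# the constant-only force path removed above no longer exists.
print(jnp.asarray(y).shape, jnp.asarray(z).shape)  # (4, 6) (3, 2)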