Lazy sublanguage #1668

Merged · 1 commit · Jan 8, 2020
15 changes: 8 additions & 7 deletions jax/interpreters/pxla.py
@@ -30,6 +30,7 @@
from ..config import flags
from .. import core
from .. import linear_util as lu
from .. import lazy
from ..abstract_arrays import (ConcreteArray, ShapedArray, array_types,
raise_to_shaped)
from ..util import partial, unzip2, concatenate, prod, safe_map
@@ -360,8 +361,7 @@ def __getitem__(self, idx):
ids = self._ids()
device_buffer = self.device_buffers[ids[idx]]
aval = ShapedArray(self.aval.shape[1:], self.aval.dtype)
handler = xla.aval_to_result_handler(None, aval)
return handler(device_buffer)
return xla.DeviceArray(aval, None, lazy.array(aval.shape), device_buffer)
else:
return super(ShardedDeviceArray, self).__getitem__(idx)

@@ -376,11 +376,14 @@ def _shard_sharded_device_array(x, devices, assignments):
return (xla.device_put(x[assignments[r]], devices[r]) for r in range(n))
shard_arg_handlers[ShardedDeviceArray] = _shard_sharded_device_array

def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
return c.Constant(onp.asarray(val), canonicalize_types=canonicalize_types)
xb.register_constant_handler(ShardedDeviceArray, _sharded_device_array_constant_handler)

core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray
xla.device_put_handlers[ShardedDeviceArray] = xla._device_put_array
xla.pytype_aval_mappings[ShardedDeviceArray] = lambda x: x.aval
xla.pytype_aval_mappings[ShardedDeviceArray] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[ShardedDeviceArray] = identity
xb.register_constant_handler(ShardedDeviceArray, xla._device_array_constant_handler)


class ChunkedDeviceArray(ShardedDeviceArray):
@@ -398,10 +401,8 @@ def __getitem__(self, idx):

core.pytype_aval_mappings[ChunkedDeviceArray] = ConcreteArray
xla.device_put_handlers[ChunkedDeviceArray] = xla._device_put_array
xla.pytype_aval_mappings[ChunkedDeviceArray] = lambda x: x.aval
xla.pytype_aval_mappings[ChunkedDeviceArray] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[ChunkedDeviceArray] = identity
xb.register_constant_handler(ChunkedDeviceArray,
xla._device_array_constant_handler)


### the xla_pmap primitive and its rules are comparable to xla_call in xla.py
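
Note on this file's changes: `__getitem__` now builds the per-shard `DeviceArray` directly instead of looking up a result handler, and the bespoke constant handler is dropped in favor of the shared `xla._device_array_constant_handler`. With the new `array_result_handler` in xla.py below, the two spellings construct the same kind of object. A minimal sketch of that equivalence, using stand-in definitions rather than the real JAX internals:

```python
# Stand-ins for illustration only -- not the real DeviceArray or lazy module.
from functools import partial

def DeviceArray(aval, device, lazy_expr, buf):
    return ("DeviceArray", aval, device, lazy_expr, buf)

def lazy_array(shape):          # trivial ("identity") lazy expression
    return ("trivial", shape)

aval, buf = ("f32[4]",), "shard_buffer_0"
handler = partial(DeviceArray, aval, None, lazy_array((4,)))  # what a result handler builds
assert handler(buf) == DeviceArray(aval, None, lazy_array((4,)), buf)  # direct construction
```
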
138 changes: 93 additions & 45 deletions jax/interpreters/xla.py
@@ -32,6 +32,7 @@
from .. import ad_util
from .. import tree_util
from .. import dtypes
from .. import lazy
from .. import linear_util as lu
from ..abstract_arrays import (ConcreteArray, ShapedArray, AbstractToken,
make_shaped_array, array_types, raise_to_shaped,
@@ -78,7 +79,7 @@ def aval_to_result_handler(device, aval):
xla_result_handlers = {}
xla_result_handlers[core.AbstractUnit] = lambda _, __: lambda _: core.unit
def array_result_handler(device, aval):
return partial(DeviceArray, raise_to_shaped(aval), device)
return partial(DeviceArray, raise_to_shaped(aval), device, lazy.array(aval.shape))
xla_result_handlers[ShapedArray] = array_result_handler
xla_result_handlers[ConcreteArray] = array_result_handler
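
Every `DeviceArray` produced by an XLA execution now carries a lazy expression, and for these results the expression is trivial: it just says "the value is the underlying buffer". A rough sketch of what `lazy.array(shape)` denotes, with a representation assumed for illustration rather than quoted from `jax/lazy.py`:

```python
# Rough sketch of the idea behind lazy.array(shape); the representation here is
# an assumption for illustration, not the actual jax/lazy.py definition.
from collections import namedtuple

LazyExpr = namedtuple("LazyExpr", ["input", "shape", "dims"])
ArrayVar = namedtuple("ArrayVar", [])   # "use the wrapped device buffer as-is"

def lazy_array(shape):
    # Trivial expression: output dim i reads buffer dim i; no constant input,
    # no broadcast, no transpose -- forcing it is a no-op.
    return LazyExpr(ArrayVar(), shape, tuple(range(len(shape))))

def is_trivial(lexpr):
    return (type(lexpr.input) is ArrayVar
            and lexpr.dims == tuple(range(len(lexpr.shape))))

assert is_trivial(lazy_array((2, 3)))
```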

@@ -146,6 +147,7 @@ def _make_abstract_python_scalar(typ, _):
for _t in dtypes.python_scalar_dtypes.keys():
pytype_aval_mappings[_t] = partial(_make_abstract_python_scalar, _t)


### op-by-op execution

def arg_spec(x):
@@ -536,6 +538,7 @@ def _xla_callable_args(c, avals, tuple_args):
def _pval_to_result_handler(device, pval):
pv, const = pval
if pv is None:
const = _device_put_impl(const, device) if device else const
return lambda _: const
else:
return aval_to_result_handler(device, pv)
@@ -566,8 +569,8 @@ def _execute_trivial(jaxpr, device, consts, handlers, *args):
_map(env.setdefault, jaxpr.constvars, consts)
outs = [canonicalize_dtype(v.val) if type(v) is Literal else env[v]
for v in jaxpr.outvars]
return [x if type(x) is DeviceArray else handler(device_put(x, device))
for handler, x in zip(handlers, outs)]
return [_copy_device_array_to_device(x, device) if type(x) is DeviceArray
else h(device_put(x, device)) for h, x in zip(handlers, outs)]

def make_tuple(bufs, device, backend):
return xb.get_backend(backend).make_tuple(bufs, device)
@@ -678,7 +681,7 @@ def __init__(self, aval, device_buffer):
self.device_buffer = device_buffer

def _check_if_deleted(self):
if self.device_buffer is None:
if self.device_buffer is deleted_buffer:
raise ValueError("DeviceValue has been deleted.")

def block_until_ready(self):
@@ -702,13 +705,14 @@ class DeviceArray(DeviceValue):
"""A DeviceArray is an ndarray backed by a single device memory buffer."""
# We don't subclass ndarray because that would open up a host of issues,
# but lax_numpy.py overrides isinstance behavior and attaches ndarray methods.
__slots__ = ["_npy_value", "_device"]
__slots__ = ["_npy_value", "_device", "_lazy_expr"]
__array_priority__ = 100

def __init__(self, aval, device, device_buffer):
def __init__(self, aval, device, lazy_expr, device_buffer):
self.aval = aval
self.device_buffer = device_buffer
self._device = device and (type(device), device.id)
self._lazy_expr = lazy_expr

self._npy_value = None
if not core.skip_checks:
@@ -720,7 +724,10 @@ def __init__(self, aval, device, device_buffer):
def _value(self):
self._check_if_deleted()
if self._npy_value is None:
self._npy_value = self.device_buffer.to_py()
if is_device_constant(self):
self._npy_value = lazy.eval_lexpr(self._lazy_expr, None)
else:
self._npy_value = _force(self).device_buffer.to_py()
self._npy_value.flags.writeable = False
return self._npy_value
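
`_value` now has two paths to a host value: a device constant owns no real buffer, so its lazy expression is evaluated entirely on the host; anything else is forced (its lazy expression applied on device) and then copied back. A hedged numpy sketch of what host-side evaluation of a simple lazy expression can look like, using an assumed `(input, shape, dims)` form rather than the real `jax/lazy.py` one:

```python
# Hedged sketch: evaluate a simple lazy expression on the host with numpy.
import numpy as np

def eval_lexpr(input_kind, shape, dims, x=None):
    # input_kind picks the base value: the wrapped buffer, or a synthesized constant.
    if input_kind[0] == "array":
        base = x
    elif input_kind[0] == "iota":
        _, dtype, size = input_kind
        base = np.arange(size, dtype=dtype)
    else:
        raise ValueError(input_kind)
    # dims[i] names the base axis feeding output axis i; None marks a broadcast axis.
    perm = [d for d in dims if d is not None]
    base = np.transpose(base, perm)
    in_shape = [1 if d is None else shape[i] for i, d in enumerate(dims)]
    return np.broadcast_to(np.reshape(base, in_shape), shape)

# A lazy "device constant": iota(5) broadcast across a new leading axis of size 2.
print(eval_lexpr(("iota", np.float32, 5), (2, 5), (None, 0)))
```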

@@ -747,7 +754,7 @@ def copy(self):
def copy_to_host_async(self):
"""Requests a copy of the buffer to the host."""
self._check_if_deleted()
if self._npy_value is None:
if self._npy_value is None and not is_device_constant(self):
self.device_buffer.copy_to_host_async()

def delete(self):
@@ -762,7 +769,7 @@ def delete(self):
time of deletion.
"""
self.device_buffer.delete()
self.device_buffer = None
self.device_buffer = deleted_buffer
self._npy_value = None

def __repr__(self):
@@ -837,31 +844,97 @@ def __eq__(self, other): return self._value == other
def __hash__(self):
raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")

class DeletedBuffer(object): pass
deleted_buffer = DeletedBuffer()

class DeviceConstant(object):
__slots__ = ["_device"]
def __init__(self, device=None): self._device = device
def device(self): return self._device
def to_py(self): return None

def is_device_constant(x):
return type(x) is DeviceArray and type(x.device_buffer) is DeviceConstant
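
Two distinct stand-ins for a missing buffer appear here: `deleted_buffer` marks an array whose buffer was explicitly freed, while a `DeviceConstant` instance marks a lazy constant that never had a buffer in the first place (its value comes from the lazy expression alone). Keeping them as separate objects, rather than reusing `None`, means the "deleted" check and the "constant" check cannot be confused. A small sketch of the distinction, with a hypothetical shell class standing in for `DeviceArray`:

```python
# Sketch only: Holder is a hypothetical shell, not the real DeviceArray.
class DeletedBuffer(object): pass
deleted_buffer = DeletedBuffer()

class DeviceConstant(object):
    def __init__(self, device=None): self._device = device
    def to_py(self): return None          # nothing was ever allocated on device

class Holder(object):
    def __init__(self, buf): self.device_buffer = buf

x = Holder(DeviceConstant())
assert x.device_buffer.to_py() is None    # lazy constant: value comes from its lexpr
x.device_buffer = deleted_buffer          # what delete() records now
assert x.device_buffer is deleted_buffer  # unambiguously "deleted", not "constant"
```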

core.literalable_types.add(DeviceArray)
core.pytype_aval_mappings[DeviceArray] = ConcreteArray
pytype_aval_mappings[DeviceArray] = lambda x: x.aval
pytype_aval_mappings[DeviceArray] = op.attrgetter('aval')
canonicalize_dtype_handlers[DeviceArray] = identity

def _device_array_constant_handler(c, val, canonicalize_types=True):
return c.Constant(onp.asarray(val), canonicalize_types=canonicalize_types)
if is_device_constant(val):
return lazy.stage_lexpr(c, val._lazy_expr, None)
else:
base_val = c.Constant(val.device_buffer.to_py())
return lazy.stage_lexpr(c, val._lazy_expr, base_val)
xb.register_constant_handler(DeviceArray, _device_array_constant_handler)

def _device_put_device_array(x, device):
# TODO(skye): we're assuming the DeviceBuffers without "platform" are
# XrtBuffers. Figure out a less risky way to deal with XrtBuffers.
if (not hasattr(x.device_buffer, "platform") or
xb.get_device_backend(device).platform == x.device_buffer.platform()):
x = _copy_device_array_to_device(x, device)
return _force(x).device_buffer
device_put_handlers[DeviceArray] = _device_put_device_array

def _copy_device_array_to_device(x, device):
if is_device_constant(x):
return DeviceArray(x.aval, device, x._lazy_expr, DeviceConstant(device))
elif xb.get_device_backend(device).platform == x.device_buffer.platform():
if device is None or x.device_buffer.device() == device:
return x.device_buffer
return x
else:
return x.device_buffer.copy_to_device(device)
moved_buf = x.device_buffer.copy_to_device(device)
else:
# Buffers from different XLA backends are passed through the host.
return xc.Buffer.from_pyval(x, device, backend=xb.get_device_backend(device))
device_put_handlers[DeviceArray] = _device_put_device_array
# Buffers from different XLA backends are passed through the host.
moved_buf = xc.Buffer.from_pyval(x.device_buffer.to_py(), device,
backend=xb.get_device_backend(device))
return DeviceArray(x.aval, device, x._lazy_expr, moved_buf)

def _force(x):
if lazy.is_trivial(x._lazy_expr):
return x
else:
# force x on the device where it lives, but preserve stickiness on result
if x._device:
device = x._device
sticky = True
else:
d = x.device_buffer.device()
device = d and (type(d), d.id)
sticky = False
force_fun = _lazy_force_computation(sticky, x.aval, device, x._lazy_expr)
return force_fun(x)
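
`_force` materializes a non-trivial lazy expression into a plain buffer-backed array; a trivial expression is returned untouched. It runs the forcing computation on the device the array already occupies and tracks whether the result should stay committed ("sticky") to that device, which only happens when the array had an explicit `_device`. A toy sketch of that placement decision, with strings standing in for device objects:

```python
# Toy sketch of the placement logic in _force; strings stand in for devices.
def force_placement(lazy_is_trivial, explicit_device, buffer_device):
    if lazy_is_trivial:
        return "no-op: return x unchanged"
    # Force where the array already lives; remember stickiness for the result.
    device = explicit_device if explicit_device is not None else buffer_device
    sticky = explicit_device is not None
    return "run forcing computation on %s, result sticky=%s" % (device, sticky)

print(force_placement(True,  None,    "gpu:0"))   # trivial: nothing to do
print(force_placement(False, None,    "gpu:0"))   # forced in place, stays unsticky
print(force_placement(False, "gpu:1", "gpu:1"))   # forced on its committed device
```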

@cache()
def _lazy_force_computation(sticky, aval, device, lexpr):
c = xb.make_computation_builder("lazy_force")
if lazy.is_constant(lexpr):
param = None
else:
idxs = [(src, dst) for dst, src in enumerate(lexpr.dims) if src is not None]
param_shape = [None] * len(idxs)
for src, dst in idxs:
param_shape[src] = aval.shape[dst]
param = c.ParameterWithShape(xc.Shape.array_shape(aval.dtype, param_shape))
xla_out = lazy.stage_lexpr(c, lexpr, param)
built_c = c.Build(xla_out)

device = _device_from_arg_devices([device])
options = xb.get_compile_options(device_assignment=device and (device.id,))
backend = xb.get_device_backend(device)
compiled = built_c.Compile(compile_options=options, backend=backend)

result_device = device if sticky else None
handler = partial(DeviceArray, aval, result_device, lazy.array(aval.shape))
if lazy.is_constant(lexpr):
force_fun = lambda _: handler(compiled.Execute([]))
else:
force_fun = lambda x: handler(compiled.Execute([x.device_buffer]))
return force_fun
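
When the lazy expression is not a constant, the compiled forcing computation takes the underlying buffer as its single parameter, and the parameter's shape is recovered by inverting `lexpr.dims`, which maps each output dimension to the input dimension it reads (None for broadcast dimensions). A worked instance of the same loop from the diff, assuming a broadcast of a length-4 vector to shape (3, 4):

```python
# Worked example of the parameter-shape bookkeeping above (same loop as the diff).
aval_shape = (3, 4)
dims = (None, 0)    # output dim 0 is broadcast; output dim 1 reads input dim 0

idxs = [(src, dst) for dst, src in enumerate(dims) if src is not None]
param_shape = [None] * len(idxs)
for src, dst in idxs:
    param_shape[src] = aval_shape[dst]

print(idxs)         # [(0, 1)]: input dim 0 feeds output dim 1
print(param_shape)  # [4]: the XLA parameter is a rank-1 array of length 4
```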


def _device_put_impl(x, device=None):
if type(x) is DeviceArray:
return _copy_device_array_to_device(x, device)

try:
a = abstractify(x)
except TypeError:
@@ -904,28 +977,3 @@ def _foil_cse(c, x):
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
zero = c.Broadcast(c.Constant(onp.array(0, dtype=dtype)), shape)
return c.Select(pred, x, zero)


### lazy constants

class DeviceConstant(DeviceArray):
def copy_to_host_async(self): pass

@staticmethod
def constant_handler(c, constant_instance, canonicalize_types=True):
assert False

def _instantiate_device_constant(const, device=None, backend=None, cutoff=1e6):
# dispatch an XLA Computation to build the constant on the device if it's
# large, or alternatively build it on the host and transfer it if it's small
assert isinstance(const, DeviceConstant)
backend = xb.get_backend(device.platform) if device else xb.get_backend(backend)
if const.size > cutoff:
c = xb.make_computation_builder("constant_instantiating_computation")
xla_const = const.constant_handler(c, const)
device_assignment = (device.id,) if device else None
opts = xb.get_compile_options(device_assignment=device_assignment)
compiled = c.Build(xla_const).Compile((), opts, backend=backend)
return compiled.Execute(())
else:
return xc.Buffer.from_pyval(onp.asarray(const), device, backend=backend)
2 changes: 1 addition & 1 deletion jax/lax/__init__.py
@@ -21,7 +21,7 @@
_input_dtype, _const, _eq_meet, _safe_mul,
_broadcasting_select, _check_user_dtype_supported,
_one, _const, _upcast_fp16_for_computation,
_broadcasting_shape_rule)
_broadcasting_shape_rule, _eye, _tri, _delta)
from .lax_control_flow import *
from .lax_fft import *
from .lax_parallel import *
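
The only change here re-exports the private `_eye`, `_tri`, and `_delta` helpers from `lax`, presumably so downstream modules (e.g. `lax_numpy`) can build these structured constants through the new lazy machinery instead of materializing them eagerly. For intuition only, these are the kinds of values involved; the numpy one-liners below are stand-ins, not the actual `lax._eye`/`_tri`/`_delta` signatures:

```python
# Numpy stand-ins for intuition only -- not the lax helpers themselves.
import numpy as np

eye   = np.eye(3, dtype=np.float32)    # identity matrix
tri   = np.tri(3, dtype=np.float32)    # lower-triangular mask
delta = (np.arange(3)[:, None] == np.arange(3)).astype(np.float32)  # Kronecker delta

print(eye); print(tri); print(delta)
```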