
Commit

lazy sublanguage
mattjj committed Dec 31, 2019
1 parent 322ebe7 commit cfcf4f1
Showing 16 changed files with 755 additions and 349 deletions.
2 changes: 1 addition & 1 deletion jax/core.py
@@ -116,7 +116,7 @@ def __init__(self, val):
     if type(val) in literalable_types:
       try:
         self.hash = hash((val.item(), val.dtype))
-      except (TypeError, AttributeError):
+      except (TypeError, AttributeError, ValueError):
        self.hash = None
 
   def __hash__(self):
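
Note on the broadened except clause: the hash caching in Literal.__init__ is best-effort, and anything raised while hashing (val.item(), val.dtype) should simply fall back to an uncached literal. One plausible trigger for the newly caught ValueError is calling .item() on a multi-element array. A minimal, self-contained sketch of that pattern; the wrapper class below is illustrative, not JAX's Literal:

import numpy as np

class HashCachingLiteral:
  # Illustrative stand-in: precompute a hash when possible, otherwise record None.
  def __init__(self, val):
    self.val = val
    try:
      self.hash = hash((val.item(), val.dtype))
    except (TypeError, AttributeError, ValueError):
      self.hash = None

print(HashCachingLiteral(np.float32(1.0)).hash is not None)  # True: a scalar hashes fine
print(HashCachingLiteral(np.arange(3)).hash)  # None: .item() on a size-3 array raises ValueError
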
24 changes: 16 additions & 8 deletions jax/interpreters/pxla.py
@@ -120,6 +120,7 @@ def _shard_array(x, devices, assignments):
 def _shard_device_array(x, devices, assignments):
   nrep = len(devices)
   xs = x._unstack()

   return (xla.device_put(xs[assignments[r]], devices[r])
           for r in range(nrep))
 shard_arg_handlers[xla.DeviceArray] = _shard_device_array
@@ -376,11 +377,14 @@ def _shard_sharded_device_array(x, devices, assignments):
   return (xla.device_put(x[assignments[r]], devices[r]) for r in range(n))
 shard_arg_handlers[ShardedDeviceArray] = _shard_sharded_device_array
 
+def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
+  return c.Constant(onp.asarray(val), canonicalize_types=canonicalize_types)
+xb.register_constant_handler(ShardedDeviceArray, _sharded_device_array_constant_handler)
+
 core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray
 xla.device_put_handlers[ShardedDeviceArray] = xla._device_put_array
 xla.pytype_aval_mappings[ShardedDeviceArray] = lambda x: x.aval
 xla.canonicalize_dtype_handlers[ShardedDeviceArray] = identity
-xb.register_constant_handler(ShardedDeviceArray, xla._device_array_constant_handler)
 
 
 class ChunkedDeviceArray(ShardedDeviceArray):
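
Context for the new constant handler: values get embedded into an XLA computation by a handler looked up by Python type via xb.register_constant_handler, the same type-keyed dispatch used by shard_arg_handlers and device_put_handlers above. Previously ShardedDeviceArray reused the DeviceArray handler; now it gets its own, which materializes the value as a single dense numpy array via onp.asarray(val) before emitting the constant. A rough sketch of the registry pattern with illustrative names, not JAX's internals:

import numpy as np

class FakeBuilder:
  # stand-in for a computation builder; Constant just records a dense array
  def Constant(self, arr):
    return ("constant", np.asarray(arr))

constant_handlers = {}  # maps a Python type to an embedding function

def register_constant_handler(ty, handler):
  constant_handlers[ty] = handler

def embed_constant(builder, val):
  # dispatch on type; default to converting the value to a dense numpy array
  handler = constant_handlers.get(type(val), _default_constant_handler)
  return handler(builder, val)

def _default_constant_handler(builder, val):
  return builder.Constant(np.asarray(val))
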
@@ -413,9 +417,9 @@ def xla_pmap_impl(fun, *args, **params):
   backend = params.pop('backend', None)
   assert not params
 
-  abstract_args = map(xla.abstractify, args)
+  avals = [xla.abstractify(x) for x in args]
   compiled_fun = parallel_callable(fun, backend, axis_name, axis_size, devices,
-                                   *abstract_args)
+                                   *avals)
   return compiled_fun(*args)
 
 @lu.cache
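
Aside on the rename from abstract_args to avals: an aval (abstract value) summarizes an argument by shape and dtype only, with no data attached, which is roughly what lets the @lu.cache-decorated parallel_callable reuse a compilation for later calls whose arguments match in shape and dtype. A toy stand-in, with hypothetical names rather than jax.core's types:

import numpy as np
from collections import namedtuple

Aval = namedtuple("Aval", ["shape", "dtype"])  # shape/dtype only, no values

def abstractify(x):
  x = np.asarray(x)
  return Aval(x.shape, x.dtype)

avals = [abstractify(x) for x in (np.ones((8, 3), np.float32), 2.0)]
# -> [Aval(shape=(8, 3), dtype=dtype('float32')), Aval(shape=(), dtype=dtype('float64'))]
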
@@ -454,8 +458,8 @@ def dynamic_fun(dummy, *args):
     with extend_dynamic_axis_env(axis_name, dummy.trace, global_axis_size):
       return fun.call_wrapped(*args)
 
-  avals = tuple(map(partial(shard_aval, axis_size), avals))
-  pvals = [pe.PartialVal((aval, core.unit)) for aval in avals]
+  sharded_avals = tuple(map(partial(shard_aval, axis_size), avals))
+  pvals = [pe.PartialVal((aval, core.unit)) for aval in sharded_avals]
   pval = pe.PartialVal([core.abstract_unit, core.unit])  # dummy value for axis env
   with core.new_master(pe.StagingJaxprTrace, True) as master:
     jaxpr, (out_pvals, consts, env) = pe.trace_to_subjaxpr(
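
The sharded_avals rename also names what shard_aval does to each aval before tracing: the leading mapped axis of size axis_size is split across replicas, so tracing sees per-replica avals with that axis removed (replicate further down does the inverse, prepending an (axis_size,) axis). A rough sketch under that assumption, with hypothetical names:

from collections import namedtuple

Aval = namedtuple("Aval", ["shape", "dtype"])  # hypothetical shape/dtype record

def shard_aval_sketch(axis_size, aval):
  # drop the leading mapped axis, which is what gets split across replicas
  assert aval.shape and aval.shape[0] == axis_size
  return Aval(aval.shape[1:], aval.dtype)

print(shard_aval_sketch(8, Aval((8, 3), "float32")))  # Aval(shape=(3,), dtype='float32')
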
@@ -493,11 +497,11 @@ def dynamic_fun(dummy, *args):
   num_global_replicas = global_axis_size * jaxpr_replicas
   axis_env = xla.AxisEnv(num_global_replicas, [axis_name], [global_axis_size], devices)
 
-  tuple_args = len(avals) > 100  # pass long arg lists as tuple for TPU
+  tuple_args = len(sharded_avals) > 100  # pass long arg lists as tuple for TPU
 
   c = xb.make_computation_builder("pmap_{}".format(fun.__name__))
   xla_consts = _map(c.Constant, consts)
-  xla_args = xla._xla_callable_args(c, avals, tuple_args)
+  xla_args = _pmap_callable_args(c, sharded_avals, tuple_args)
   out_nodes = xla.jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, (), *xla_args)
   built = c.Build(c.Tuple(*out_nodes))
 
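The tuple_args flag is unchanged apart from the rename: per the comment, long argument lists are passed to the compiled computation as a single tuple parameter on TPU, presumably because very many individual entry parameters are costly there. A plain-Python illustration of just the packing decision, with the 100 threshold taken from the line above and everything else illustrative:

def pack_args(args, threshold=100):
  # many arguments -> pass one tuple parameter; otherwise pass them individually
  args = tuple(args)
  return (args,) if len(args) > threshold else args

print(len(pack_args(range(3))))    # 3 separate parameters
print(len(pack_args(range(200))))  # 1 tuple parameter
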
@@ -537,6 +541,10 @@ def dynamic_fun(dummy, *args):
 class ResultToPopulate(object): pass
 result_to_populate = ResultToPopulate()
 
+def _pmap_callable_args(c, avals, tuple_args):
+  # TODO(mattjj): support laziness for broadcasted axes to map
+  return xla._xla_callable_args(c, avals, tuple_args)
+
 def _pvals_to_results_handler(size, nrep, out_pvals, devices, backend):
   nouts = len(out_pvals)
   handlers = [_pval_to_result_handler(size, nrep, pval, devices, backend)
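
_pmap_callable_args is a pure pass-through for now; the TODO marks the optimization it is staged for: when a pmapped argument is merely broadcast along the mapped axis, every replica receives identical data, so the full stacked array never needs to be materialized before sharding. A sketch of that idea with illustrative names, not the eventual JAX implementation:

import numpy as np

def shard_broadcast_eager(base, axis_size):
  # eager path: build the full stacked array, then slice out one shard per replica
  stacked = np.broadcast_to(base, (axis_size,) + np.shape(base))
  return [stacked[i] for i in range(axis_size)]

def shard_broadcast_lazy(base, axis_size):
  # lazy path: hand every replica the unexpanded base value directly
  return [base] * axis_size

base = np.ones(3)
pairs = zip(shard_broadcast_eager(base, 4), shard_broadcast_lazy(base, 4))
assert all((a == b).all() for a, b in pairs)
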
@@ -585,7 +593,7 @@ def replicate(val, axis_size, nrep, devices=None, backend=None):
 
   aval = xla.abstractify(val)
   aval = ShapedArray((axis_size,) + aval.shape, aval.dtype)
-  device_buffers = [xla.device_put(val, d) for d in devices]
+  device_buffers = [xla.device_put(xla._force(val), d) for d in devices]
   return ShardedDeviceArray(aval, device_buffers)
 
 def _pval_to_result_handler(axis_size, nrep, pval, devices, backend):
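
The added xla._force call is where this hunk meets the commit's theme: with the lazy sublanguage, a DeviceArray may carry an unevaluated expression (such as a broadcast) over a smaller base buffer, and it has to be forced to a dense value before operations that need the actual bytes, like the device transfers here. A conceptual sketch only, with illustrative names rather than JAX's lazy machinery:

import numpy as np

class LazyBroadcastValue:
  # a tiny lazy value: an unexpanded base plus the logical broadcast shape
  def __init__(self, base, shape):
    self.base = base
    self.shape = shape

def force(x):
  # materialize a lazy value into a dense array; pass dense values through
  if isinstance(x, LazyBroadcastValue):
    return np.broadcast_to(x.base, x.shape)
  return np.asarray(x)

val = LazyBroadcastValue(np.float32(1.0), (4, 4))
dense = force(val)  # only now is the full (4, 4) array built
print(dense.shape)  # (4, 4)
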
(Diffs for the remaining 14 changed files are not shown.)

0 comments on commit cfcf4f1
