Skip to content

Commit

Permalink
omnistaging wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Jun 27, 2020
1 parent 496cde6 commit 5e8507a
Show file tree
Hide file tree
Showing 44 changed files with 1,292 additions and 1,445 deletions.
1 change: 0 additions & 1 deletion docs/jax.lax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ Operators
squeeze
sub
tan
tie_in
top_k
transpose

Expand Down
430 changes: 184 additions & 246 deletions docs/jaxpr.rst

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion jax/ad_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.


from jax import core
from .core import (lattice_join, Primitive, Unit, unit, AbstractUnit,
valid_jaxtype, raise_to_shaped, get_aval)
from .tree_util import register_pytree_node
Expand All @@ -27,7 +28,10 @@
jaxval_adders[Unit] = lambda _, __: unit

def add_jaxvals(x, y):
return add_jaxvals_p.bind(x, y)
if core.get_aval(x) is core.get_aval(y) is core.abstract_unit:
return core.unit
else:
return add_jaxvals_p.bind(x, y)

add_jaxvals_p = Primitive('add_any')

Expand Down
108 changes: 12 additions & 96 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,7 @@ def xla_computation(fun: Callable,
static_argnums: Union[int, Iterable[int]] = (),
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
backend: Optional[str] = None,
tuple_args: bool = False,
instantiate_const_outputs: bool = True) -> Callable:
tuple_args: bool = False) -> Callable:
"""Creates a function that produces its XLA computation given example args.
Args:
Expand All @@ -246,13 +245,6 @@ def xla_computation(fun: Callable,
tuple_args: Optional bool, defaults to ``False``. If ``True``, the resulting
XLA computation will have a single tuple argument that is unpacked into
the specified function arguments.
instantiate_const_outputs: Optional bool, defaults to ``True``. If
``False``, then :py:func:`xla_computation` does not instantiate
constant-valued outputs in the XLA computation, and so the result is
closer to the computation that :py:func:`jax.jit` produces and may be more
useful for studying :py:func:`jit` behavior. If ``True``, then
constant-valued outputs are instantiated in the XLA computation, which may
be more useful for staging computations out of JAX entirely.
Returns:
A wrapped version of ``fun`` that when applied to example arguments returns a
Expand Down Expand Up @@ -332,11 +324,11 @@ def xla_computation(fun: Callable,

def make_axis_env(nreps):
if axis_env is None:
return xla.AxisEnv(nreps)
return xla.AxisEnv(nreps, (), (), None)
else:
nreps = nreps * prod(size for name, size in axis_env)
names, sizes = zip(*axis_env)
return xla.AxisEnv(nreps, names, sizes)
return xla.AxisEnv(nreps, names, sizes, None)

def abstractify(x):
return ShapedArray(onp.shape(x), dtypes.result_type(x))
Expand All @@ -350,10 +342,7 @@ def computation_maker(*args, **kwargs):
jax_args, in_tree = tree_flatten((args, kwargs))
jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
avals = map(abstractify, jax_args)
pvals = [pe.PartialVal.unknown(aval) for aval in avals]
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals,
instantiate=instantiate_const_outputs,
stage_out=True)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
jaxpr, _ = xla.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
c = xb.make_computation_builder('xla_computation_{}'.format(fun_name))
Expand Down Expand Up @@ -1175,8 +1164,8 @@ def __eq__(self, other):
return type(other) is _TempAxisName and self.obj == other.obj


def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, *,
in_axes=0, backend: Optional[str] = None) -> Callable:
def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
) -> Callable:
warn("soft_pmap is an experimental feature and probably has bugs!")
_check_callable(fun)
axis_name = _TempAxisName(fun) if axis_name is None else axis_name
Expand All @@ -1193,45 +1182,11 @@ def f_pmapped(*args, **kwargs):
axis_size = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "soft_pmap")
for arg in args_flat: _check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)

chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count(backend))
if chunk_size == 0 and leftover:
return pmap(fun, axis_name, backend=backend)(*args) # can map directly onto hardware
elif leftover:
msg = ("soft_pmap mapped axis size must be divisible by the number of "
"XLA devices (or be less than or equal to that number), but got "
"an axis size of {} with {} devices.")
raise ValueError(msg.format(axis_size, pxla.unmapped_device_count()))
num_chunks = axis_size // chunk_size

reshaped_args = [_reshape_split(num_chunks, x) for x in args_flat]
soft_mapped_fun = pxla.split_axis(flat_fun, axis_name, chunk_size)
# TODO(tomhennigan): soft_pmap should support buffer donation.
donated_invars = (False,) * len(reshaped_args)
reshaped_outs = pxla.xla_pmap(soft_mapped_fun, *reshaped_args, backend=backend,
axis_name=axis_name, axis_size=num_chunks,
global_axis_size=None, devices=None,
name=soft_mapped_fun.__name__,
mapped_invars=mapped_invars,
donated_invars=donated_invars)
outs = [_reshape_merge(out) for out in reshaped_outs]
outs = pxla.soft_pmap(flat_fun, *args_flat, axis_name=axis_name,
axis_size=axis_size, mapped_invars=mapped_invars)
return tree_unflatten(out_tree(), outs)
return f_pmapped

def _reshape_split(num_chunks, x):
aval = core.get_aval(x)
if aval is core.abstract_unit:
return x
else:
return x.reshape((num_chunks, x.shape[0] // num_chunks) + x.shape[1:])

def _reshape_merge(x):
aval = core.get_aval(x)
if aval is core.abstract_unit:
return x
else:
return x.reshape((-1,) + x.shape[2:])


def _papply(fun):
# This function is for testing purposes.
Expand All @@ -1249,37 +1204,6 @@ def papply_fun(*args, **kwargs):
return papply_fun, axis_name


def _parallelize(fun):
axis_name = _TempAxisName(fun)

def pfun(*args):
f = lu.wrap_init(fun)
args_flat, in_tree = tree_flatten(args)
f, out_tree = flatten_fun_nokwargs(f, in_tree)
axis_size = _mapped_axis_size(
in_tree, args_flat, (0,) * len(args_flat), "parallelize")

chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count())
if chunk_size == 0 and leftover:
return pmap(fun, axis_name)(*args) # can map directly onto hardware
elif leftover:
raise ValueError
num_chunks = axis_size // chunk_size

reshaped_args = [_reshape_split(num_chunks, x) for x in args_flat]
f, out_axes = parallel.papply_transform(f, axis_name, axis_size)
f = pxla.split_axis(f, axis_name, chunk_size)
outs = pxla.xla_pmap(f, *reshaped_args, backend=None, axis_name=axis_name,
axis_size=num_chunks, global_axis_size=None,
devices=None, name=f.__name__)
outs = map(_reshape_merge, outs)
outs = [batching.matchaxis(axis_size, 0, dst, x)
for dst, x in zip(out_axes(), outs)]
return tree_unflatten(out_tree(), outs)

return pfun


def mask(fun: Callable, in_shapes, out_shape) -> Callable:
_check_callable(fun)
unique_ids = masking.UniqueIds()
Expand Down Expand Up @@ -1616,10 +1540,6 @@ def make_jaxpr(fun: Callable,
if isinstance(static_argnums, int):
static_argnums = (static_argnums,)

def pv_like(x):
aval = xla.abstractify(x)
return pe.PartialVal.unknown(aval)

@wraps(fun)
def jaxpr_maker(*args, **kwargs):
wrapped = lu.wrap_init(fun)
Expand All @@ -1628,11 +1548,9 @@ def jaxpr_maker(*args, **kwargs):
wrapped, _ = argnums_partial(wrapped, dyn_argnums, args)
jax_args, in_tree = tree_flatten((args, kwargs))
jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
in_pvals = map(pv_like, jax_args)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
jaxtree_fun, in_pvals, instantiate=True, stage_out=True)
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
in_avals = tuple(raise_to_shaped(in_aval) for in_aval, _ in in_pvals)
in_avals = map(xla.abstractify, jax_args)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
in_avals = tuple(raise_to_shaped(in_aval) for in_aval in in_avals)
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
return typed_jaxpr

Expand Down Expand Up @@ -1896,13 +1814,11 @@ def __repr__(self):
return '<jax.custom_transforms function {fun}>'.format(fun=self.__name__)

def __call__(self, *args):
# TODO(mattjj): instead of tracing to a jaxpr, use process_call
args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_pvals = [pe.PartialVal.unknown(raise_to_shaped(core.get_aval(x)))
for x in args_flat]
with core.initial_style_staging():
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
outs = self.prim.bind(*it.chain(consts, args_flat), jaxpr=jaxpr,
in_tree=in_tree, out_tree=out_tree(),
num_consts=len(consts))
Expand Down
Loading

0 comments on commit 5e8507a

Please sign in to comment.