Skip to content

Commit

Permalink
Separate axis splitting from collective handling (#4082)
Browse files Browse the repository at this point in the history
This makes the vmap collective handling a bit more flexible and allows
ppermute support to be added.
  • Loading branch information
apaszke authored Aug 18, 2020
1 parent ace23fa commit 36f3a36
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 32 deletions.
40 changes: 19 additions & 21 deletions jax/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,23 +133,20 @@ def process_primitive(self, primitive, tracers, params):
if all(bdim is not_mapped for bdim in dims_in):
return primitive.bind(*vals_in, **params)
elif config.omnistaging_enabled and primitive in collective_rules:
axes_names = params['axis_name']
if not isinstance(axes_names, (tuple, list)):
axes_names = (axes_names,)
for i, axis_name in enumerate(axes_names):
axis_names = params['axis_name']
if not isinstance(axis_names, (tuple, list)):
axis_names = (axis_names,)
for i, axis_name in enumerate(axis_names):
frame = core.axis_frame(axis_name)
if frame.tag is self.master:
if params['axis_index_groups'] is not None:
raise NotImplementedError("axis_index_groups not supported in vmap collectives")
result = collective_rules[primitive](vals_in, dims_in, frame.size, **params)
remaining_axes = axes_names[:i] + axes_names[(i+1):]
# TODO: This assumes that the collective always returns the same result for each
# array element, which is not true for the ones that are not reductions!
if remaining_axes:
new_params = dict(params, axis_name=remaining_axes)
return primitive.bind(*result, **new_params)
else:
return result
if frame.tag is not self.master:
continue
# We run the split_axis rule with tracers, which is supposed to never
# mix this axis name with another one. We will handle any invocations
# of collectives over the vmapped axis in a recursive call from a tracer.
if len(axis_names) > 1:
return split_axis(primitive, axis_name, tracers, params)
vals_out, dims_out = collective_rules[primitive](vals_in, dims_in, frame.size, **params)
return map(partial(BatchTracer, self), vals_out, dims_out)
# TODO(mattjj,phawkins): if no rule implemented, could vmap-via-map here
batched_primitive = get_primitive_batcher(primitive)
val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
Expand Down Expand Up @@ -444,9 +441,10 @@ def batch_jaxpr(jaxpr, size, batched, instantiate):
return jaxpr_out, batched_out()


# It is assumed that all collectives specified below are commutative
# and associative over axis names if they support tuples. That is,
# they have to satisfy:
# collective(x, ('i', 'j')) == collective(x, ('j', 'i'))
# == collective(collective(x, 'j'), 'i')
# collective_rules can assume that the collective is only carried out throughout
# the vmapped axis (i.e. no tuples in axis_name).
collective_rules: Dict[core.Primitive, Callable] = {}
# split_axis_rules peel a single axis name off a collective invoked over a
# tuple of names; they are dispatched through split_axis below.
split_axis_rules: Dict[core.Primitive, Callable] = {}

def split_axis(primitive, split_name, args, params):
  """Look up and invoke the registered split-axis rule for `primitive`."""
  rule = split_axis_rules[primitive]
  return rule(split_name, args, params)
76 changes: 65 additions & 11 deletions jax/lax/lax_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from jax.util import partial, unzip2, prod
from jax.lib import xla_client as xc
from jax.config import config
from jax.numpy import lax_numpy

from jax.interpreters.pxla import axis_index

Expand Down Expand Up @@ -305,6 +306,41 @@ def _allreduce_translation_rule(prim, c, val, *, axis_name, axis_index_groups,
replica_groups_protos = xc.make_replica_groups(replica_groups)
return xops.AllReduce(val, computation, replica_groups_protos, None, None)

# It is assumed that all collectives that use this rule are commutative
# and associative over axis names if they support tuples. That is,
# they have to satisfy:
# collective(x, ('i', 'j')) == collective(x, ('j', 'i'))
# == collective(collective(x, 'j'), 'i')
def _split_axis_comm_assoc(primitive, split_name, args, params):
axis_names = params['axis_name']
assert isinstance(axis_names, tuple)
if params['axis_index_groups'] is not None:
raise NotImplementedError("axis_index_groups not supported in axis splitting. "
"Please open a feature request!")
remaining_axes = list(axis_names)
remaining_axes.remove(split_name)
remaining_axes = tuple(remaining_axes)
split_params = dict(params, axis_name=split_name)
remain_params = dict(params, axis_name=remaining_axes)
split_result = primitive.bind(*args, **split_params)
return primitive.bind(*split_result, **remain_params)

# NB: This is only used for collectives that do not include the vmapped axis name,
# which is why the rule is so simple. All other collectives go through split_axis.
def _collective_batcher(prim, args, dims, **params):
return prim.bind(*args, **params), dims

def _batched_reduction_collective(prim, if_mapped, if_unmapped,
                                  vals_in, dims_in, axis_size,
                                  axis_name, axis_index_groups):
  """vmap rule for reduction collectives (psum/pmax/pmin-style).

  Each value is reduced with `if_mapped(val, dim)` when it carries a batch
  dimension and with `if_unmapped(val, axis_size)` otherwise.  Outputs are
  never mapped, since reducing over the axis removes it.

  Raises:
    NotImplementedError: if `axis_index_groups` is set.
  """
  if axis_index_groups is not None:
    raise NotImplementedError("axis_index_groups not implemented in vmap collectives. "
                              "Please open a feature request!")
  vals_out = []
  for val, dim in zip(vals_in, dims_in):
    if dim is batching.not_mapped:
      vals_out.append(if_unmapped(val, axis_size))
    else:
      vals_out.append(if_mapped(val, dim))
  dims_out = [batching.not_mapped] * len(vals_in)
  return vals_out, dims_out

def _replica_groups(axis_env, axis_name, axis_index_groups):
replica_groups = xla.axis_groups(axis_env, axis_name)
if axis_index_groups is not None:
Expand Down Expand Up @@ -386,32 +422,39 @@ def _psum_transpose_rule(cts, axis_name, axis_index_groups):
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)
ad.deflinear(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.split_axis_rules[psum_p] = partial(_split_axis_comm_assoc, psum_p)
batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
batching.collective_rules[psum_p] = \
lambda vals, dims, axis_size, **_: [v.sum(d) if d is not batching.not_mapped else
axis_size * v
for v, d in zip(vals, dims)]
partial(_batched_reduction_collective,
psum_p,
lambda v, d: v.sum(d),
lambda v, axis_size: axis_size * v)


pmax_p = core.Primitive('pmax')
pmax_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
xla.parallel_translations[pmax_p] = \
partial(_allreduce_translation_rule, lax.max_p)
batching.split_axis_rules[pmax_p] = partial(_split_axis_comm_assoc, pmax_p)
batching.primitive_batchers[pmax_p] = partial(_collective_batcher, pmax_p)
batching.collective_rules[pmax_p] = \
lambda vals, dims, axis_size, **_: [v.max(d) if d is not batching.not_mapped else v
for v, d in zip(vals, dims)]
# pxla.split_axis_rules[pmax_p] = \
# partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)
partial(_batched_reduction_collective,
pmax_p,
lambda v, d: v.max(d),
lambda v, axis_size: v)


pmin_p = core.Primitive('pmin')
pmin_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
xla.parallel_translations[pmin_p] = \
partial(_allreduce_translation_rule, lax.min_p)
batching.split_axis_rules[pmin_p] = partial(_split_axis_comm_assoc, pmin_p)
batching.primitive_batchers[pmin_p] = partial(_collective_batcher, pmin_p)
batching.collective_rules[pmin_p] = \
lambda vals, dims, axis_size, **_: [v.min(d) if d is not batching.not_mapped else v
for v, d in zip(vals, dims)]
# pxla.split_axis_rules[pmin_p] = \
# partial(_allreduce_split_axis_rule, pmin_p, lax._reduce_min)
partial(_batched_reduction_collective,
pmin_p,
lambda v, d: v.min(d),
lambda v, axis_size: v)


def _ppermute_translation_rule(c, x, *, axis_name, axis_env, perm, platform):
Expand All @@ -433,11 +476,22 @@ def _ppermute_transpose_rule(t, perm, axis_name):
inverse_perm = list(zip(dsts, srcs))
return [ppermute(t, axis_name=axis_name, perm=inverse_perm)]

def _ppermute_batcher(vals_in, dims_in, axis_size, axis_name, perm):
  """vmap collective rule for ppermute: apply `perm` along the batched axis.

  Args:
    vals_in: input values, possibly carrying a batch dimension.
    dims_in: batch dimension of each value, or batching.not_mapped.
    axis_size: size of the vmapped axis that `perm` permutes.
    axis_name: name of the vmapped axis (unused; dispatch already selected
      this rule for it).
    perm: sequence of (source_index, destination_index) pairs.

  Returns:
    (vals_out, dims) with batch dims unchanged; unmapped values pass through.
  """
  assert len(perm) == axis_size, "Permutation doesn't match the axis size!"
  # ppermute semantics: the entry at `dst` receives the value held at `src`
  # (out[dst] = in[src]), so the gather index for output position `dst` is
  # `src`.  NOTE: writing perm_indices[src] = dst instead would compute the
  # inverse permutation — that was the bug being fixed here.
  perm_indices = np.full((axis_size,), -1, dtype=np.int32)
  for src, dst in perm:
    perm_indices[dst] = src
  vals_out = [lax_numpy.take(v, perm_indices, d) if d is not batching.not_mapped else v
              for v, d in zip(vals_in, dims_in)]
  return vals_out, dims_in

# Primitive definition and rule registrations for ppermute.
ppermute_p = core.Primitive('ppermute')
ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
ad.deflinear(ppermute_p, _ppermute_transpose_rule)
xla.parallel_translations[ppermute_p] = _ppermute_translation_rule
pxla.multi_host_supported_collectives.add(ppermute_p)
# Fix: this was registered with `pmin_p` (copy-paste from the pmin block),
# which would batch ppermute by binding the pmin primitive instead.
batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
batching.collective_rules[ppermute_p] = _ppermute_batcher


def _all_to_all_translation_rule(c, x, *, split_axis, concat_axis, axis_name,
Expand Down
16 changes: 16 additions & 0 deletions tests/batching_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,5 +1002,21 @@ def testCollective(self, collective, seq):
vmap(vmap(lambda x: x - collective(x, ('i', 'j')), axis_name='i'), axis_name='j')(x),
x - seq(x, axis=(1, 0)))

  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testPpermute(self):
    # Check that vmapping lax.ppermute over a named axis matches a direct
    # NumPy gather for several random permutations.
    nelem = 10
    ntests = 10
    x = np.arange(nelem)
    rng = np.random.RandomState(1)
    for i in range(ntests):
      perm = np.arange(nelem)
      rng.shuffle(perm)
      # Rows are (source, destination) pairs; shuffling the rows must not
      # change the permutation they encode.
      perm_pairs = np.stack([np.arange(nelem), perm], axis=-1)
      rng.shuffle(perm_pairs)
      # NOTE(review): the `[0]` indexes the list of tracers the vmap
      # collective rule returns in this commit — presumably intentional;
      # confirm against batching.process_primitive.
      # NOTE(review): the expected value `x[perm]` pins out[i] = x[perm[i]],
      # but for (src, dst) pairs ppermute's device semantics give
      # out[dst] = x[src], i.e. the inverse gather — confirm the batcher's
      # permutation direction before trusting this expectation.
      self.assertAllClose(
        vmap(lambda x: x - lax.ppermute(x, 'i', perm_pairs)[0], axis_name='i')(x),
        x - x[perm])

if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 36f3a36

Please sign in to comment.