Skip to content

Commit

Permalink
switch from XLA replica IDs to JAX axis indices
Browse files Browse the repository at this point in the history
  • Loading branch information
jekbradbury committed May 1, 2020
1 parent 6bc5528 commit 564e4a2
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 34 deletions.
2 changes: 1 addition & 1 deletion docs/jaxpr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ example
let c = add a b
e = add c d
f = psum[ axis_name=rows
replica_groups=None ] a
axis_index_groups=None ] a
g = div e f
in (g,) }
devices=None
Expand Down
6 changes: 3 additions & 3 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,9 @@ def apply_parallel_primitive(prim, *args, **params):
# look up information in the dynamic axis env.
dynamic_axis_env = _thread_local_state.dynamic_axis_env
axis_name = params.pop('axis_name')
replica_groups = params.pop('replica_groups')
if replica_groups is not None:
shape = (len(replica_groups[0]),)
axis_index_groups = params.pop('axis_index_groups')
if axis_index_groups is not None:
shape = (len(axis_index_groups[0]),)
else:
logical_size = lambda frame: frame.hard_size * (frame.soft_size or 1)
if isinstance(axis_name, (list, tuple)):
Expand Down
11 changes: 7 additions & 4 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,14 @@ def write(v, node):
ans = rule(c, axis_env, extend_name_stack(name_stack, eqn.primitive.name),
map(aval, eqn.invars), backend, *in_nodes, **new_params)
elif eqn.primitive in parallel_translations:
replica_groups = eqn.params.get('replica_groups', None)
if replica_groups is None:
replica_groups = axis_groups(axis_env, eqn.params['axis_name'])
replica_groups = axis_groups(axis_env, eqn.params['axis_name'])
axis_index_groups = eqn.params.get('axis_index_groups', None)
if axis_index_groups is not None:
replica_groups = [[axis_group[i] for i in axis_index_group]
for axis_group in replica_groups
for axis_index_group in axis_index_groups]
new_params = {k: v for k, v in eqn.params.items()
if k not in ('axis_name', 'replica_groups')}
if k not in ('axis_name', 'axis_index_groups')}
rule = parallel_translations[eqn.primitive]
ans = rule(c, *in_nodes, replica_groups=replica_groups, platform=platform,
**new_params)
Expand Down
32 changes: 16 additions & 16 deletions jax/lax/lax_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

### parallel traceables

def psum(x, axis_name, replica_groups=None):
def psum(x, axis_name, axis_index_groups=None):
"""Compute an all-reduce sum on ``x`` over the pmapped axis ``axis_name``.
If ``x`` is a pytree then the result is equivalent to mapping this function to
Expand All @@ -48,7 +48,7 @@ def psum(x, axis_name, replica_groups=None):
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
replica_groups: optional list of lists containing replica IDs (e.g. for
axis_index_groups: optional list of lists containing axis indices (e.g. for
an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first
two and last two replicas).
Expand All @@ -68,9 +68,9 @@ def psum(x, axis_name, replica_groups=None):
"""
leaves, treedef = tree_util.tree_flatten(x)
return treedef.unflatten(
psum_p.bind(*leaves, axis_name=axis_name, replica_groups=replica_groups))
psum_p.bind(*leaves, axis_name=axis_name, axis_index_groups=axis_index_groups))

def pmean(x, axis_name, replica_groups=None):
def pmean(x, axis_name, axis_index_groups=None):
"""Compute an all-reduce mean on ``x`` over the pmapped axis ``axis_name``.
If ``x`` is a pytree then the result is equivalent to mapping this function to
Expand All @@ -80,7 +80,7 @@ def pmean(x, axis_name, replica_groups=None):
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
replica_groups: optional list of lists containing replica IDs (e.g. for
axis_index_groups: optional list of lists containing axis indices (e.g. for
an axis of size 4, [[0, 1], [2, 3]] would perform pmeans over the first
two and last two replicas).
Expand All @@ -98,10 +98,10 @@ def pmean(x, axis_name, replica_groups=None):
>>> print(y)
[ 0. 0.66666667 1.33333334 2.0 ]
"""
x, n = psum((x, 1), axis_name=axis_name, replica_groups=replica_groups)
x, n = psum((x, 1), axis_name=axis_name, axis_index_groups=axis_index_groups)
return tree_util.tree_map(lambda v: v / n, x)

def pmax(x, axis_name, replica_groups=None):
def pmax(x, axis_name, axis_index_groups=None):
"""Compute an all-reduce max on ``x`` over the pmapped axis ``axis_name``.
If ``x`` is a pytree then the result is equivalent to mapping this function to
Expand All @@ -111,7 +111,7 @@ def pmax(x, axis_name, replica_groups=None):
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
replica_groups: optional list of lists containing replica IDs (e.g. for
axis_index_groups: optional list of lists containing axis indices (e.g. for
an axis of size 4, [[0, 1], [2, 3]] would perform pmaxes over the first
two and last two replicas).
Expand All @@ -120,9 +120,9 @@ def pmax(x, axis_name, replica_groups=None):
all-reduce max along the axis ``axis_name``.
"""
return tree_util.tree_map(partial(
pmax_p.bind, axis_name=axis_name, replica_groups=replica_groups), x)
pmax_p.bind, axis_name=axis_name, axis_index_groups=axis_index_groups), x)

def pmin(x, axis_name, replica_groups=None):
def pmin(x, axis_name, axis_index_groups=None):
"""Compute an all-reduce min on ``x`` over the pmapped axis ``axis_name``.
If ``x`` is a pytree then the result is equivalent to mapping this function to
Expand All @@ -132,7 +132,7 @@ def pmin(x, axis_name, replica_groups=None):
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
replica_groups: optional list of lists containing replica IDs (e.g. for
axis_index_groups: optional list of lists containing axis indices (e.g. for
an axis of size 4, [[0, 1], [2, 3]] would perform pmins over the first
two and last two replicas).
Expand All @@ -141,7 +141,7 @@ def pmin(x, axis_name, replica_groups=None):
all-reduce min along the axis ``axis_name``.
"""
return tree_util.tree_map(partial(
pmin_p.bind, axis_name=axis_name, replica_groups=replica_groups), x)
pmin_p.bind, axis_name=axis_name, axis_index_groups=axis_index_groups), x)

def ppermute(x, axis_name, perm):
"""Perform a collective permutation according to the permutation ``perm``.
Expand Down Expand Up @@ -273,9 +273,9 @@ def standard_pmap_primitive(name, multiple_results=False):


def _allreduce_split_axis_rule(prim, reducer, vals, which_mapped, axis_name,
replica_groups):
axis_index_groups):
assert tuple(which_mapped) == (True,)
assert replica_groups is None
assert axis_index_groups is None
vals = (reducer(x, [0]) for x in vals)
return prim.bind(*vals, axis_name=axis_name), False

Expand Down Expand Up @@ -341,8 +341,8 @@ def _translate(val):
partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)
ad.deflinear(psum_p, lambda *ts, axis_name, replica_groups: psum(
*ts, axis_name=axis_name, replica_groups=replica_groups))
ad.deflinear(psum_p, lambda *ts, axis_name, axis_index_groups: psum(
*ts, axis_name=axis_name, axis_index_groups=axis_index_groups))
pxla.multi_host_supported_collectives.add(psum_p)


Expand Down
18 changes: 14 additions & 4 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,9 +810,9 @@ def test_grad_of_int_errors(self):
def test_xla_computation(self):
# these tests basically check the examples in the xla_computation docstring

def h(x):
def e(x):
return np.sin(np.cos(x))
c = api.xla_computation(h)(2.)
c = api.xla_computation(e)(2.)
self.assertIn('cosine', c.GetHloText())
self.assertIn('sine', c.GetHloText())

Expand All @@ -835,6 +835,16 @@ def g(x):
self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.GetHloText())
self.assertIn('replica_groups={{0,1,2,3,4,5,6,7}}', c.GetHloText())

def h(x):
rowsum = lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]])
colsum = lax.psum(x, 'j')
return rowsum, colsum
axis_env = [('i', 4), ('j', 2)]
c = api.xla_computation(h, axis_env=axis_env)(5.)
self.assertIn('all-reduce', c.GetHloText())
self.assertIn('replica_groups={{0,2},{4,6},{1,3},{5,7}}', c.GetHloText())
self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.GetHloText())

def test_xla_computation_args(self):
def foo(x, y, z):
return x + y + z
Expand Down Expand Up @@ -1852,8 +1862,8 @@ def inner(x):
call_jaxpr={ lambda ; d b a.
let c = add a b
e = add c d
f = psum[ axis_name=rows
replica_groups=None ] a
f = psum[ axis_index_groups=None
axis_name=rows ] a
g = div e f
in (g,) }
devices=None
Expand Down
37 changes: 31 additions & 6 deletions tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,21 +388,46 @@ def testPsumReplicaGroups(self):
replicas = xla_bridge.device_count()
if replicas % 2 != 0:
raise SkipTest
replica_groups = onp.arange(replicas).reshape(2, replicas // 2).tolist()
f = pmap(lambda x: x - lax.psum(x, 'i', replica_groups), axis_name='i')
axis_index_groups = onp.arange(replicas).reshape(
2, replicas // 2).tolist()
f = lambda x: x - lax.psum(x, 'i', axis_index_groups=axis_index_groups)
f = pmap(f, 'i')

shape = (replicas, 4)
x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
expected_psum_1 = onp.broadcast_to(
onp.sum(x[:replicas // 2], 0), (replicas // 2, x.shape[1]))
expected_psum_2 = onp.broadcast_to(
onp.sum(x[replicas // 2:], 0), (replicas // 2, x.shape[1]))
def sum_helper(a):
return onp.broadcast_to(a.sum(0, keepdims=True),
(replicas // 2, x.shape[1]))
expected_psum_1 = sum_helper(x[:replicas // 2])
expected_psum_2 = sum_helper(x[replicas // 2:])
expected_psum = onp.concatenate([expected_psum_1, expected_psum_2], 0)
expected = x - expected_psum

ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)

def testNestedPmapReplicaGroups(self):
  """Grouped ``psum`` over the inner axis of a doubly nested ``pmap``.

  The inner axis ``'i'`` has ``device_count // 2`` entries and the outer
  axis ``'j'`` has 2.  The indices of ``'i'`` are split into two contiguous
  groups of ``device_count // 4`` each, and every element is expected to
  have the sum over its own group (taken along array axis 1) subtracted
  from it.  Skipped unless the device count is a multiple of 4.
  """
  device_count = xla_bridge.device_count()
  if device_count % 4:
    raise SkipTest

  inner_size = device_count // 2   # length of the inner mapped axis 'i'
  group_size = device_count // 4   # number of 'i' indices in each group
  groups = onp.arange(inner_size).reshape(2, group_size).tolist()

  def minus_group_sum(v):
    return v - lax.psum(v, 'i', axis_index_groups=groups)

  mapped = pmap(pmap(minus_group_sum, 'i'), 'j')

  shape = (2, inner_size, 4)
  x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

  # Reference result: sum each half of array axis 1 separately, broadcast the
  # sum back over that half, then stitch the two halves together again.
  halves = (x[:, :group_size], x[:, group_size:])
  expected_psum = onp.concatenate(
      [onp.broadcast_to(half.sum(1, keepdims=True),
                        (2, group_size, x.shape[2]))
       for half in halves],
      1)
  expected = x - expected_psum

  self.assertAllClose(mapped(x), expected, check_dtypes=False)

def testAxisGroups(self):
axis_env = xla.AxisEnv(8, ('i', 'j'), (4, 2))
groups = xla.axis_groups(axis_env, 'i')
Expand Down

0 comments on commit 564e4a2

Please sign in to comment.