Skip to content

Commit

Permalink
add test and fix jaxpr in docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jekbradbury committed Mar 11, 2020
1 parent fdc6b57 commit 8c94cbc
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 2 deletions.
3 changes: 2 additions & 1 deletion docs/jaxpr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,8 @@ example::
call_jaxpr={ lambda ; d b a.
let c = add a b
e = add c d
f = psum[ axis_name=rows ] a
f = psum[ axis_name=rows
replica_groups=None ] a
g = div e f
in g }
devices=None
Expand Down
3 changes: 2 additions & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2030,7 +2030,8 @@ def inner(x):
call_jaxpr={ lambda ; d b a.
let c = add a b
e = add c d
f = psum[ axis_name=rows ] a
f = psum[ axis_name=rows
replica_groups=None ] a
g = div e f
in g }
devices=None
Expand Down
19 changes: 19 additions & 0 deletions tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,25 @@ def sum_and_broadcast(x, axis):
expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
self.assertAllClose(ans, expected, check_dtypes=False)

def testPsumReplicaGroups(self):
replicas = xla_bridge.device_count()
if replicas % 2 != 0:
raise SkipTest
replica_groups = onp.arange(replicas).reshape(2, replicas // 2).tolist()
f = pmap(lambda x: x - lax.psum(x, 'i', replica_groups), axis_name='i')

shape = (replicas, 4)
x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
expected_psum_1 = onp.broadcast_to(
onp.sum(x[:replicas // 2], 0), (replicas // 2, x.shape[1]))
expected_psum_2 = onp.broadcast_to(
onp.sum(x[replicas // 2:], 0), (replicas // 2, x.shape[1]))
expected_psum = onp.concatenate([expected_psum_1, expected_psum_2], 0)
expected = x - expected_psum

ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)

def testAxisGroups(self):
axis_env = xla.AxisEnv(8, ('i', 'j'), (4, 2))
groups = xla.axis_groups(axis_env, 'i')
Expand Down

0 comments on commit 8c94cbc

Please sign in to comment.