Skip to content

Commit

Permalink
add test for #2822
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Apr 25, 2020
1 parent c77582c commit d7e2206
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,38 @@ def f(key):
keys = random.split(random.PRNGKey(0), n)
jax.pmap(jax.remat(f), axis_name='i')(keys)

def testPmapMapVmapCombinations(self):
# https://github.com/google/jax/issues/2822
def vv(x, y):
"""Vector-vector multiply"""
return np.dot(x, y)

def matrix_vector(x, y, parallel=True):
"""Matrix vector multiply. First batch it and then row by row"""
fv = lambda z: lax.map(lambda j: vv(j, y), z)
if parallel:
# split leading axis in two
new_x = x.reshape((jax.device_count(), -1, *x.shape[1:]))
# apply map
new_res = pmap(fv)(new_x)
# reshape back out
res = new_res.reshape(x.shape[0], *new_res.shape[2:])
else:
res = fv(x)
return res

x = random.normal(random.PRNGKey(1), (80, 5))
y = random.normal(random.PRNGKey(1), (10, 5))

result1 = vmap(lambda b: matrix_vector(x, b, True))(y) # vmap + pmap
result2 = lax.map(lambda b: matrix_vector(x, b, False), y) # map + map
result3 = lax.map(lambda b: matrix_vector(x, b, True), y) # map + pmap
result4 = np.stack([matrix_vector(x, b, False) for b in y]) # none + map

self.assertAllClose(result1, result2, check_dtypes=False, atol=1e-3, rtol=1e-3)
self.assertAllClose(result1, result3, check_dtypes=False, atol=1e-3, rtol=1e-3)
self.assertAllClose(result1, result4, check_dtypes=False, atol=1e-3, rtol=1e-3)


class PmapWithDevicesTest(jtu.JaxTestCase):

Expand Down

0 comments on commit d7e2206

Please sign in to comment.