diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 1f0d42f55d20..04cc13f3e7e8 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -1050,6 +1050,38 @@ def f(key): keys = random.split(random.PRNGKey(0), n) jax.pmap(jax.remat(f), axis_name='i')(keys) + def testPmapMapVmapCombinations(self): + # https://github.com/google/jax/issues/2822 + def vv(x, y): + """Vector-vector multiply""" + return np.dot(x, y) + + def matrix_vector(x, y, parallel=True): + """Matrix vector multiply. First batch it and then row by row""" + fv = lambda z: lax.map(lambda j: vv(j, y), z) + if parallel: + # split leading axis in two + new_x = x.reshape((jax.device_count(), -1, *x.shape[1:])) + # apply map + new_res = pmap(fv)(new_x) + # reshape back out + res = new_res.reshape(x.shape[0], *new_res.shape[2:]) + else: + res = fv(x) + return res + + x = random.normal(random.PRNGKey(1), (80, 5)) + y = random.normal(random.PRNGKey(1), (10, 5)) + + result1 = vmap(lambda b: matrix_vector(x, b, True))(y) # vmap + pmap + result2 = lax.map(lambda b: matrix_vector(x, b, False), y) # map + map + result3 = lax.map(lambda b: matrix_vector(x, b, True), y) # map + pmap + result4 = np.stack([matrix_vector(x, b, False) for b in y]) # none + map + + self.assertAllClose(result1, result2, check_dtypes=False, atol=1e-3, rtol=1e-3) + self.assertAllClose(result1, result3, check_dtypes=False, atol=1e-3, rtol=1e-3) + self.assertAllClose(result1, result4, check_dtypes=False, atol=1e-3, rtol=1e-3) + class PmapWithDevicesTest(jtu.JaxTestCase):