diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 3184003d7701..41a4512e07a1 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -237,3 +237,34 @@ def to_vmap_over_extra_batched_dims(primals, tangents): initial_style=True) mlir.register_lowering(custom_vmap_p, mlir.lower_fun( custom_vmap_impl, multiple_results=True)) + + +# -- custom vmap applications + + +def tree_split(mask, tree): + lhs = tree_map(lambda l, x: x if l else None, mask, tree) + rhs = tree_map(lambda l, x: None if l else x, mask, tree) + return lhs, rhs + +def tree_merge(mask, lhs_tree, rhs_tree): + return tree_map(lambda l, x_l, x_r: x_l if l else x_r, + mask, lhs_tree, rhs_tree) + +def sequential_vmap(f): + f = custom_vmap(f) + + @f.def_vmap + def rule(axis_size, in_batched, *args): + del axis_size + + def to_map(mapped_args): + args = tree_merge(in_batched, mapped_args, bcast_args) + return f(*args) + + mapped_args, bcast_args = tree_split(in_batched, list(args)) + out = jax.lax.map(to_map, mapped_args) + out_batched = tree_map(lambda _: True, out) + return [out], [out_batched] + + return f diff --git a/jax/custom_batching.py b/jax/custom_batching.py index d8f6317eb8d2..7e18733f2001 100644 --- a/jax/custom_batching.py +++ b/jax/custom_batching.py @@ -15,4 +15,5 @@ # flake8: noqa: F401 from jax._src.custom_batching import ( custom_vmap, + sequential_vmap, ) diff --git a/tests/api_test.py b/tests/api_test.py index e50ef7098cfa..033b7d49dc82 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -42,6 +42,7 @@ import jax.numpy as jnp from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian from jax import core, lax +from jax import custom_batching from jax._src import api, dtypes from jax.core import Primitive from jax.errors import UnexpectedTracerError @@ -6962,6 +6963,48 @@ def rule(axis_size, in_batched, xs): self.assertAllClose(jit(api.vmap(f))(xs), api.vmap(f)(xs)) self.assertAllClose(api.vmap(jit(f))(xs), api.vmap(f)(xs)) + def test_sequential_vmap_basic(self): + @custom_batching.sequential_vmap + def f(x): + return x + 1. + + def vmap_ref(xs): + return lax.map(f, xs) + + xs = jnp.arange(3.) + jaxpr = api.make_jaxpr(api.vmap(f))(xs) + jaxpr_ref = api.make_jaxpr(vmap_ref)(xs) + + self.assertEqual(str(jaxpr), str(jaxpr_ref)) + + def test_sequential_vmap_nary_same_batching(self): + @custom_batching.sequential_vmap + def f(x, y): + return x + y + + def vmap_ref(xs, ys): + return lax.map(lambda args: f(*args), (xs, ys)) + + xs, ys = jnp.arange(3.), 4. + jnp.arange(3.) + jaxpr = api.make_jaxpr(api.vmap(f))(xs, ys) + jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, ys) + + self.assertEqual(str(jaxpr), str(jaxpr_ref)) + + def test_sequential_vmap_nary_mixed_batching(self): + @custom_batching.sequential_vmap + def f(x, y): + return x + y + + def vmap_ref(xs, y): + return lax.map(lambda x: f(x, y), xs) + + xs, y = jnp.arange(3.), 4. + jaxpr = api.make_jaxpr(api.vmap(f, in_axes=(0, None)))(xs, y) + jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, y) + + self.assertEqual(str(jaxpr), str(jaxpr_ref)) + class CustomApiTest(jtu.JaxTestCase): """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs"""