Skip to content

Commit

Permalink
introduce custom_batching.sequential_vmap
Browse files Browse the repository at this point in the history
An anticipated common use of `custom_vmap` is in order to implement a
map via loop (i.e. to sequentially apply the mapped function), instead
of actually vectorizing.
  • Loading branch information
froystig committed Jan 21, 2022
1 parent 04374c7 commit 30c3a39
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 0 deletions.
31 changes: 31 additions & 0 deletions jax/_src/custom_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,34 @@ def to_vmap_over_extra_batched_dims(primals, tangents):
initial_style=True)
mlir.register_lowering(custom_vmap_p, mlir.lower_fun(
custom_vmap_impl, multiple_results=True))


# -- custom vmap applications


def tree_split(mask, tree):
lhs = tree_map(lambda l, x: x if l else None, mask, tree)
rhs = tree_map(lambda l, x: None if l else x, mask, tree)
return lhs, rhs

def tree_merge(mask, lhs_tree, rhs_tree):
return tree_map(lambda l, x_l, x_r: x_l if l else x_r,
mask, lhs_tree, rhs_tree)

def sequential_vmap(f):
f = custom_vmap(f)

@f.def_vmap
def rule(axis_size, in_batched, *args):
del axis_size

def to_map(mapped_args):
args = tree_merge(in_batched, mapped_args, bcast_args)
return f(*args)

mapped_args, bcast_args = tree_split(in_batched, list(args))
out = jax.lax.map(to_map, mapped_args)
out_batched = tree_map(lambda _: True, out)
return [out], [out_batched]

return f
1 change: 1 addition & 0 deletions jax/custom_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
# flake8: noqa: F401
from jax._src.custom_batching import (
custom_vmap,
sequential_vmap,
)
43 changes: 43 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import jax.numpy as jnp
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
from jax import core, lax
from jax import custom_batching
from jax._src import api, dtypes
from jax.core import Primitive
from jax.errors import UnexpectedTracerError
Expand Down Expand Up @@ -6962,6 +6963,48 @@ def rule(axis_size, in_batched, xs):
self.assertAllClose(jit(api.vmap(f))(xs), api.vmap(f)(xs))
self.assertAllClose(api.vmap(jit(f))(xs), api.vmap(f)(xs))

def test_sequential_vmap_basic(self):
@custom_batching.sequential_vmap
def f(x):
return x + 1.

def vmap_ref(xs):
return lax.map(f, xs)

xs = jnp.arange(3.)
jaxpr = api.make_jaxpr(api.vmap(f))(xs)
jaxpr_ref = api.make_jaxpr(vmap_ref)(xs)

self.assertEqual(str(jaxpr), str(jaxpr_ref))

def test_sequential_vmap_nary_same_batching(self):
@custom_batching.sequential_vmap
def f(x, y):
return x + y

def vmap_ref(xs, ys):
return lax.map(lambda args: f(*args), (xs, ys))

xs, ys = jnp.arange(3.), 4. + jnp.arange(3.)
jaxpr = api.make_jaxpr(api.vmap(f))(xs, ys)
jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, ys)

self.assertEqual(str(jaxpr), str(jaxpr_ref))

def test_sequential_vmap_nary_mixed_batching(self):
@custom_batching.sequential_vmap
def f(x, y):
return x + y

def vmap_ref(xs, y):
return lax.map(lambda x: f(x, y), xs)

xs, y = jnp.arange(3.), 4.
jaxpr = api.make_jaxpr(api.vmap(f, in_axes=(0, None)))(xs, y)
jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, y)

self.assertEqual(str(jaxpr), str(jaxpr_ref))


class CustomApiTest(jtu.JaxTestCase):
"""Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs"""
Expand Down

0 comments on commit 30c3a39

Please sign in to comment.