Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce custom_batching.sequential_vmap #9275

Merged
merged 1 commit into from
Jan 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions jax/_src/custom_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,34 @@ def to_vmap_over_extra_batched_dims(primals, tangents):
initial_style=True)
mlir.register_lowering(custom_vmap_p, mlir.lower_fun(
custom_vmap_impl, multiple_results=True))


# -- custom vmap applications


def tree_split(mask, tree):
lhs = tree_map(lambda l, x: x if l else None, mask, tree)
rhs = tree_map(lambda l, x: None if l else x, mask, tree)
return lhs, rhs

def tree_merge(mask, lhs_tree, rhs_tree):
return tree_map(lambda l, x_l, x_r: x_l if l else x_r,
mask, lhs_tree, rhs_tree)

def sequential_vmap(f):
f = custom_vmap(f)

@f.def_vmap
def rule(axis_size, in_batched, *args):
del axis_size

def to_map(mapped_args):
args = tree_merge(in_batched, mapped_args, bcast_args)
return f(*args)

mapped_args, bcast_args = tree_split(in_batched, list(args))
out = jax.lax.map(to_map, mapped_args)
out_batched = tree_map(lambda _: True, out)
return [out], [out_batched]

return f
1 change: 1 addition & 0 deletions jax/custom_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
# flake8: noqa: F401
from jax._src.custom_batching import (
custom_vmap,
sequential_vmap,
)
43 changes: 43 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import jax.numpy as jnp
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
from jax import core, lax
from jax import custom_batching
from jax._src import api, dtypes
from jax.core import Primitive
from jax.errors import UnexpectedTracerError
Expand Down Expand Up @@ -6962,6 +6963,48 @@ def rule(axis_size, in_batched, xs):
self.assertAllClose(jit(api.vmap(f))(xs), api.vmap(f)(xs))
self.assertAllClose(api.vmap(jit(f))(xs), api.vmap(f)(xs))

def test_sequential_vmap_basic(self):
@custom_batching.sequential_vmap
def f(x):
return x + 1.

def vmap_ref(xs):
return lax.map(f, xs)

xs = jnp.arange(3.)
jaxpr = api.make_jaxpr(api.vmap(f))(xs)
jaxpr_ref = api.make_jaxpr(vmap_ref)(xs)

self.assertEqual(str(jaxpr), str(jaxpr_ref))

def test_sequential_vmap_nary_same_batching(self):
@custom_batching.sequential_vmap
def f(x, y):
return x + y

def vmap_ref(xs, ys):
return lax.map(lambda args: f(*args), (xs, ys))

xs, ys = jnp.arange(3.), 4. + jnp.arange(3.)
jaxpr = api.make_jaxpr(api.vmap(f))(xs, ys)
jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, ys)

self.assertEqual(str(jaxpr), str(jaxpr_ref))

def test_sequential_vmap_nary_mixed_batching(self):
@custom_batching.sequential_vmap
def f(x, y):
return x + y

def vmap_ref(xs, y):
return lax.map(lambda x: f(x, y), xs)

xs, y = jnp.arange(3.), 4.
jaxpr = api.make_jaxpr(api.vmap(f, in_axes=(0, None)))(xs, y)
jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, y)

self.assertEqual(str(jaxpr), str(jaxpr_ref))


class CustomApiTest(jtu.JaxTestCase):
"""Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs"""
Expand Down