custom batching (vmap) #9073

Open
froystig opened this issue Dec 30, 2021 · 3 comments
Labels
enhancement New feature or request P1 (soon) Assignee is working on this now, among other tasks. (Assignee required)

Comments

@froystig
Member

froystig commented Dec 30, 2021

Support custom batching, i.e. the ability to register a custom "vmap rule" for any given function. Example usage would look something like:

from jax import vmap, numpy as jnp
from jax.custom_batching import custom_vmap  # custom_vmap lives in jax.custom_batching in released JAX

@custom_vmap
def vector_dot(u, v):
  assert u.ndim == v.ndim == 1
  return u @ v

@vector_dot.def_vmap
def vector_dot_vmap_rule(axis_size, in_batched, u, v):
  u_batched, v_batched = in_batched
  if u_batched:
    assert u.ndim == 2 and u.shape[0] == axis_size
    print('lhs batched')
  if v_batched:
    assert v.ndim == 2 and v.shape[0] == axis_size
    print('rhs batched')
  if u_batched and v_batched:
    out = jnp.sum(u * v, axis=1)
  else:
    out = u @ v if u_batched else v @ u
  return out, u_batched or v_batched

def f(u, v):
  return jnp.exp(vector_dot(u, v))

x = lambda *shape: jnp.ones(shape)
vmap(f, in_axes=(0, None))(x(4, 3), x(3))  # -> lhs batched
vmap(f, in_axes=(1, None))(x(3, 4), x(3))  # -> lhs batched
vmap(f, in_axes=(None, 0))(x(3), x(4, 3))  # -> rhs batched
vmap(f, in_axes=(0, 0))(x(4, 3), x(4, 3))  # -> lhs batched, rhs batched

This would enable #7199 and would help avoid #8853, among other things.
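A self-contained check of the example above, written against `jax.custom_batching.custom_vmap` rather than the proposed top-level import. It exercises the `in_axes=(1, None)` case with non-uniform data and the both-batched case, and compares against plain `vmap` of `u @ v`. It assumes, as the rule's own shape assertions do, that batched arguments reach the rule with the batch axis moved to position 0.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

@custom_vmap
def vector_dot(u, v):
    # The unbatched definition only ever sees rank-1 inputs.
    assert u.ndim == v.ndim == 1
    return u @ v

@vector_dot.def_vmap
def vector_dot_vmap_rule(axis_size, in_batched, u, v):
    u_batched, v_batched = in_batched
    # Batched arguments are assumed to arrive with the batch axis at position 0,
    # even when vmap was called with in_axes=1.
    if u_batched:
        assert u.ndim == 2 and u.shape[0] == axis_size
    if v_batched:
        assert v.ndim == 2 and v.shape[0] == axis_size
    if u_batched and v_batched:
        out = jnp.sum(u * v, axis=1)
    else:
        out = u @ v if u_batched else v @ u
    return out, u_batched or v_batched

u = jnp.arange(12.0).reshape(3, 4)  # mapped along axis 1, so the batch size is 4
v = jnp.ones(3)

custom = jax.vmap(vector_dot, in_axes=(1, None))(u, v)  # column sums of u
reference = jax.vmap(lambda a, b: a @ b, in_axes=(1, None))(u, v)

a = jnp.arange(12.0).reshape(4, 3)
b = jnp.ones((4, 3))
both = jax.vmap(vector_dot)(a, b)  # both arguments batched: row sums of a
```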

@froystig
Member Author

A rough update on where the implementation is at present:

  • Batching, forward-mode AD (e.g. jvp, jacfwd), and compilation are all supported.
  • Reverse-mode AD (specifically, linearization by partial evaluation and transposition) is a work in progress.
  • The underlying primitive currently stages out the custom-batched function eagerly. We may want to move to a delayed tracing approach.
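A minimal sketch of the supported compositions listed above, assuming the released `jax.custom_batching.custom_vmap` API. The hypothetical `scale` function is batched through its custom rule under `jit`, then differentiated in forward mode with `jax.jvp` and `jax.jacfwd`. The rule records its `axis_size` in a Python list only to make it visible that the rule ran during tracing.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

rule_calls = []  # hypothetical bookkeeping: one entry each time the rule is traced

@custom_vmap
def scale(x):
    return 2.0 * x

@scale.def_vmap
def scale_vmap_rule(axis_size, in_batched, xs):
    (x_batched,) = in_batched  # one flag per positional argument
    rule_calls.append(axis_size)
    # xs carries the batch axis at position 0; elementwise scaling is batch-safe.
    return 2.0 * xs, x_batched

batched = jax.jit(jax.vmap(scale))  # batching plus compilation

xs = jnp.array([1.0, 2.0, 3.0])
ys = batched(xs)                                            # [2, 4, 6]
primal, tangent = jax.jvp(batched, (xs,), (jnp.ones(3),))   # forward-mode AD
jac = jax.jacfwd(batched)(xs)                               # 2 * identity
```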

We're also thinking about whether to recommend the general custom_vmap function as a direct user-facing API, or whether instead to encourage more structured uses via functions like sequential_vmap or variations thereof.
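`sequential_vmap` did ship next to `custom_vmap` in `jax.custom_batching`. As a sketch of the more structured style mentioned above: a function decorated with it is batched by running the original body once per batch element in a loop, so no rule has to be written. The `vector_dot` here is a trimmed copy of the example in the issue.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import sequential_vmap

@sequential_vmap
def vector_dot(u, v):
    # When batched, this body runs once per batch element inside a loop,
    # so it only ever sees unbatched, rank-1 inputs.
    assert u.ndim == v.ndim == 1
    return u @ v

us = jnp.arange(12.0).reshape(4, 3)
v = jnp.ones(3)
out = jax.vmap(vector_dot, in_axes=(0, None))(us, v)  # row sums of us
```

The trade-off is performance: the loop gives up the vectorization a hand-written rule could provide, in exchange for needing no rule at all.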

@froystig
Member Author

Let's look to fix/support #13283 as part of this as well.

@patrick-kidger
Collaborator

By the way, I notice that the above example doesn't include axis_name, which in general needs to be passed in as well.
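For context on why the name matters: ordinary `vmap` accepts an `axis_name`, and collectives inside the mapped function refer to the batch axis by that name. A rule that receives only `axis_size` and `in_batched` cannot emit such a collective for the enclosing axis. The snippet uses only standard `jax.vmap` and `jax.lax.pmean`; it deliberately shows no `custom_vmap` signature, since how the name would reach a rule is the open question here.

```python
import jax
import jax.numpy as jnp
from jax import lax

def center(x):
    # pmean averages x over every element of the vmapped axis named "batch".
    return x - lax.pmean(x, axis_name="batch")

xs = jnp.arange(4.0)
out = jax.vmap(center, axis_name="batch")(xs)  # xs minus its mean (1.5)
```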

copybara-service bot pushed a commit that referenced this issue Dec 4, 2024
…d docstrings.

The `custom_vmap` API is discussed in #9073, and it remains somewhat experimental and incomplete, but it is sufficiently widely used that it seemed worth adding it to the docs.

One specific pain point with `custom_vmap` is that it doesn't support reverse-mode autodiff, so I also added a better error message for this case. Before this change, using `grad` with a `custom_vmap` function would fail with an `assert` deep within the JAX internals. This now fails with a `NotImplementedError` that describes the problem.

PiperOrigin-RevId: 702731889
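One possible workaround for the missing reverse-mode support, an assumption rather than something proposed in this thread: wrap the `custom_vmap` function in `jax.custom_vjp` and write its backward pass by hand, so autodiff never has to transpose the `custom_vmap` primitive. `scale`, `scale_rev`, and the rules below are hypothetical illustrations.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

@custom_vmap
def scale(x):
    return 2.0 * x

@scale.def_vmap
def scale_vmap_rule(axis_size, in_batched, xs):
    (x_batched,) = in_batched
    return 2.0 * xs, x_batched

# Hypothetical wrapper: custom_vjp supplies the reverse-mode rule by hand.
@jax.custom_vjp
def scale_rev(x):
    return scale(x)

def scale_rev_fwd(x):
    return scale(x), None  # no residuals needed: the derivative is constant

def scale_rev_bwd(_, g):
    return (2.0 * g,)  # one cotangent per input

scale_rev.defvjp(scale_rev_fwd, scale_rev_bwd)

g = jax.grad(scale_rev)(3.0)
gs = jax.vmap(jax.grad(scale_rev))(jnp.array([1.0, 2.0, 3.0]))  # fwd hits the custom rule
```

With this split, `grad` only evaluates `scale` in the forward pass; the cost is that the backward rule must be kept in sync with the function manually.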
copybara-service bot pushed the same commit, referencing this issue, again on Dec 5, 2024 and Dec 9, 2024 (PiperOrigin-RevId: 702731889), and once more on Dec 9, 2024 with PiperOrigin-RevId: 704353963.