custom batching (vmap) #9073
A rough update on where the implementation is at present:
We're also thinking about whether to recommend the general …
Let's look to fix/support #13283 as part of this as well.
By the way, I notice that the above example doesn't include …
…d docstrings. The `custom_vmap` API is discussed in #9073, and it remains somewhat experimental and incomplete, but it is sufficiently widely used that it seemed worth adding it to the docs. One specific pain point with `custom_vmap` is that it doesn't support reverse-mode autodiff, so I also added a better error message for this case. Before this change, using `grad` with a `custom_vmap` function would fail with an `assert` deep within the JAX internals. This now fails with a `NotImplementedError` that describes the problem. PiperOrigin-RevId: 702731889
…d docstrings. The `custom_vmap` API is discussed in #9073, and it remains somewhat experimental and incomplete, but it is sufficiently widely used that it seemed worth adding it to the docs. One specific pain point with `custom_vmap` is that it doesn't support reverse-mode autodiff, so I also added a better error message for this case. Before this change, using `grad` with a `custom_vmap` function would fail with an `assert` deep within the JAX internals. This now fails with a `NotImplementedError` that describes the problem. PiperOrigin-RevId: 704353963
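A minimal sketch of the limitation described in the commit message, assuming a recent JAX release that exposes `jax.custom_batching.custom_vmap`. The names `square`, `square_vmap_rule`, and the helper `reverse_mode_supported` are hypothetical and only for illustration. Forward mode (`jax.jvp`) and `vmap` go through, while `jax.grad` raises; the helper catches the exception broadly, because per the commit message newer versions raise `NotImplementedError` and older ones failed on an internal `assert`.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


@custom_vmap
def square(x):
    # Unbatched definition.
    return x * x


@square.def_vmap
def square_vmap_rule(axis_size, in_batched, xs):
    # Batched rule (hypothetical example): xs carries a leading batch axis of
    # length axis_size. It computes the same math as `square`, so batched and
    # unbatched results agree. in_batched holds one bool per argument.
    return xs * xs, in_batched[0]


def reverse_mode_supported(fn, x):
    """Hypothetical helper: report whether jax.grad(fn)(x) succeeds."""
    try:
        jax.grad(fn)(x)
        return True
    except Exception as err:
        # NotImplementedError after the change; an internal assert before it.
        print(f"jax.grad failed with {type(err).__name__}")
        return False


x = jnp.float32(3.0)
# Forward mode: primal is x * x, tangent is 2 * x * 1.0.
y, ty = jax.jvp(square, (x,), (jnp.float32(1.0),))
# vmap dispatches to square_vmap_rule.
ys = jax.vmap(square)(jnp.arange(3.0))
grad_ok = reverse_mode_supported(square, x)
```

Keeping the rule mathematically identical to the unbatched function is a deliberate choice here: it means `vmap` changes only how the computation is carried out, not its result.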
Support custom batching, i.e. the ability to register a custom "vmap rule" for any given function. Example usage would look something like:
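As a sketch, and not the example from the original post, here is how a custom rule can be registered with the `custom_vmap` API discussed in this issue, assuming the `jax.custom_batching.custom_vmap` decorator and its `def_vmap` method as found in recent JAX releases. The names `f` and `f_vmap_rule` are illustrative. The rule deliberately computes `x * y` instead of `x + y`, so it is visible when `vmap` dispatches to it.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


@custom_vmap
def f(x, y):
    # Unbatched semantics: plain elementwise sum.
    return x + y


@f.def_vmap
def f_vmap_rule(axis_size, in_batched, xs, ys):
    # axis_size: length of the mapped axis.
    # in_batched: one bool per argument saying whether it carries that axis.
    x_batched, y_batched = in_batched
    # Give any unbatched argument a leading axis so both inputs line up.
    if not x_batched:
        xs = jnp.broadcast_to(xs, (axis_size,) + xs.shape)
    if not y_batched:
        ys = jnp.broadcast_to(ys, (axis_size,) + ys.shape)
    # Deliberately differs from f (product, not sum) so the rule's use is observable.
    out_batched = True  # the returned array has a leading batch axis
    return xs * ys, out_batched


xs = jnp.arange(3.0)
ys = jnp.arange(1.0, 4.0)
unbatched = f(jnp.float32(2.0), jnp.float32(3.0))            # calls f itself
both = jax.vmap(f)(xs, ys)                                   # calls f_vmap_rule
x_only = jax.vmap(f, in_axes=(0, None))(xs, jnp.float32(2.0))  # only x is batched
```

The rule returns the output together with a matching pytree of booleans (`out_batched`); a single `True` suffices here because `f` returns a single array.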
This would enable #7199 and would help avoid #8853, among other things (e.g. #12345).