Sequential map with in_axes and out_axes #15041

Open
Edenhofer opened this issue Mar 16, 2023 · 0 comments
Labels: enhancement (New feature or request)


Edenhofer commented Mar 16, 2023

vmap is amazing, but for control-flow or memory reasons it is sometimes necessary to resort to jax.lax.scan. However, scan is less convenient to use than vmap because of its comparatively restrictive input and output schema: it always loops over the first axis, constants can only be passed to the loop body via partial function application (leading to unnecessary recompiles, see #14743), and its output always stacks pytrees along the first array axis.
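For illustration, the manual workaround that scan currently requires can be sketched as follows. This is a minimal sketch with made-up names (step, scan_over_columns) and toy data, not code from this issue: to loop over axis 1, the input is first transposed with jnp.moveaxis, the loop-invariant constant w is bound with functools.partial, and the output, which scan stacks along axis 0, is moved back by hand.

```python
import functools

import jax
import jax.numpy as jnp


def step(w, carry, x_col):
    # One loop iteration: x_col is a single column of x, w is a loop-invariant constant.
    y = w * x_col + carry
    return carry + 1.0, y


def scan_over_columns(x, w):
    # lax.scan always iterates over the leading axis, so move axis 1 to the front first.
    xs = jnp.moveaxis(x, 1, 0)  # (n, m) -> (m, n)
    # The constant w is bound with functools.partial instead of being passed per step.
    body = functools.partial(step, w)
    _, ys = jax.lax.scan(body, jnp.zeros(()), xs)  # ys is stacked along axis 0: (m, n)
    # Move the stacked axis back to where we actually want it (axis 1).
    return jnp.moveaxis(ys, 0, 1)  # (m, n) -> (n, m)


x = jnp.arange(6.0).reshape(2, 3)
w = jnp.array([1.0, 2.0])
out = scan_over_columns(x, w)
```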

It is arguably convenient to be able to loop over any axis of the input and to stack the output along any axis. None of this is tricky to do, and there are already a couple of simple wrappers floating around that take in_axes and out_axes arguments akin to vmap, e.g. flax.core.axes_scan.

I think it would make sense to have an extended version of scan in JAX. This version could be either a simple sequential drop-in replacement for vmap or a more flexible scan with a carry. I don't think jax.lax.scan itself needs to become axis-aware; I would merely argue in favor of a simple wrapper. Phrased differently: people will transpose axes to make scan happy anyway, so why not do it for them?
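A wrapper of this kind, with a carry, might look roughly like the following. This is a minimal sketch under stated assumptions: the name scan_along is hypothetical (it is not the API of flax.core.axes_scan or of JAX), and it uses a single integer axis for every leaf of xs and every leaf of the stacked output rather than per-leaf in_axes/out_axes pytrees. It only automates the moveaxis calls users would otherwise write by hand.

```python
import jax
import jax.numpy as jnp


def scan_along(f, init, xs, in_axis=0, out_axis=0):
    """Hypothetical helper: lax.scan over `in_axis` of every leaf of `xs`,
    stacking the per-step outputs along `out_axis`.

    Assumption: one integer axis for all leaves (no per-leaf axes)."""
    # Bring the scanned axis of every input leaf to the front, as lax.scan expects.
    xs_front = jax.tree_util.tree_map(lambda a: jnp.moveaxis(a, in_axis, 0), xs)
    carry, ys = jax.lax.scan(f, init, xs_front)
    # lax.scan stacks outputs along axis 0; move that axis to the requested position.
    ys = jax.tree_util.tree_map(lambda a: jnp.moveaxis(a, 0, out_axis), ys)
    return carry, ys


def cumsum_step(carry, x_col):
    # Running sum over columns; emit the current partial sum at every step.
    carry = carry + x_col
    return carry, carry


x = jnp.arange(6.0).reshape(2, 3)
# Loop over the columns of x (axis 1) and stack the partial sums back along axis 1.
total, running = scan_along(cumsum_step, jnp.zeros(2), x, in_axis=1, out_axis=1)
```

In this usage example, running should agree with jnp.cumsum(x, axis=1), and total is the final carry, i.e. the row sums.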

I have a sequential drop-in replacement for vmap, smap, at https://gist.github.com/Edenhofer/34207ad5b2b60e564e21bde9c350efd6 that I am happy to polish up if a sequential vmap alternative is all that is desired, but I would also be happy to see something along the lines of flax.core.axes_scan in JAX.
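For comparison, a sequential stand-in for vmap could look roughly like the following. This is a hypothetical sketch (the name smap_sketch is made up), not the implementation in the gist: it handles positional arguments only, accepts in_axes as an int or as a per-argument tuple of int/None (None meaning "not mapped"), and applies a single integer out_axes to every output leaf. Unmapped arguments are simply closed over by the scan body, and the carry is None, so lax.scan acts as a plain sequential map.

```python
import jax
import jax.numpy as jnp


def smap_sketch(f, in_axes=0, out_axes=0):
    """Hypothetical sequential stand-in for jax.vmap (positional args only)."""

    def wrapped(*args):
        # Normalize in_axes to one entry per positional argument.
        axes = in_axes if isinstance(in_axes, (tuple, list)) else (in_axes,) * len(args)
        mapped = [i for i, ax in enumerate(axes) if ax is not None]
        # Move each mapped argument's loop axis to the front (for all of its leaves).
        xs = tuple(
            jax.tree_util.tree_map(lambda a, ax=axes[i]: jnp.moveaxis(a, ax, 0), args[i])
            for i in mapped
        )

        def body(carry, x_slices):
            # Reassemble the argument list: sliced mapped args, unmapped args as-is.
            full = list(args)
            for i, x in zip(mapped, x_slices):
                full[i] = x
            return carry, f(*full)

        _, ys = jax.lax.scan(body, None, xs)
        # Outputs are stacked along axis 0; move them to out_axes like vmap would.
        return jax.tree_util.tree_map(lambda a: jnp.moveaxis(a, 0, out_axes), ys)

    return wrapped


def affine(w, x, b):
    return w @ x + b


w = jnp.arange(6.0).reshape(2, 3)
xs = jnp.arange(12.0).reshape(3, 4)  # loop over the 4 columns (axis 1)
b = jnp.ones(2)
seq = smap_sketch(affine, in_axes=(None, 1, None), out_axes=1)(w, xs, b)
vec = jax.vmap(affine, in_axes=(None, 1, None), out_axes=1)(w, xs, b)
```

In this example, seq and vec should agree up to floating-point tolerance; the sequential version trades vmap's batching for a loop with lower peak memory.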


Edenhofer added the enhancement (New feature or request) label on Mar 16, 2023.