`vmap` is amazing, but for control-flow or memory reasons it is sometimes necessary to resort to `jax.lax.scan`. `scan`, though, is less convenient to use than `vmap` because of its comparatively restrictive input and output schema: it always loops over the leading axis of its inputs, constants can only be passed to the loop body using partial function application (leading to unnecessary recompiles, #14743), and it always stacks the output pytrees along the leading array axis.
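To make the restriction concrete, here is a small illustration (the toy running-sum body is an assumption, chosen only for brevity): to loop over axis 1 with `lax.scan`, the caller has to move that axis to the front before the call and move the stacked output axis back afterwards.

```python
import jax.numpy as jnp
from jax import lax


def cumsum_step(carry, x):
    # carry: running sum so far; x: one slice of the input along the scanned axis
    carry = carry + x
    return carry, carry  # (new carry, per-step output)


xs = jnp.arange(12.0).reshape(3, 4)  # we want to loop over axis 1 (length 4)

# lax.scan always iterates over axis 0 and stacks outputs along axis 0,
# so the scanned axis has to be moved to the front first ...
final, ys = lax.scan(cumsum_step, jnp.zeros(3), jnp.moveaxis(xs, 1, 0))
# ... and the stacked axis has to be moved back afterwards.
ys = jnp.moveaxis(ys, 0, 1)

print(ys.shape)  # (3, 4)
```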
It would be convenient to be able to loop over any axis of the input and to stack the output along any axis. None of this is tricky to do, and there are a couple of simple wrappers floating around that take `in_axes` and `out_axes` arguments akin to `vmap`, e.g. `flax.core.axes_scan`.
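A minimal sketch of such a wrapper, assuming a single integer axis that applies to every leaf of `xs` and every leaf of the output (per-argument axes as in `flax.core.axes_scan` are left out). `axes_scan_sketch` is a hypothetical name, not an existing JAX or Flax API; it only hides the two `moveaxis` calls around `lax.scan`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def axes_scan_sketch(f, init, xs, in_axis=0, out_axis=0, length=None, reverse=False):
    """Hypothetical wrapper: like lax.scan, but scans over `in_axis` of every
    leaf of `xs` and stacks every leaf of the output along `out_axis`."""
    # Move the scanned axis to the front of every input leaf.
    xs_front = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, in_axis, 0), xs)
    carry, ys = lax.scan(f, init, xs_front, length=length, reverse=reverse)
    # Move the stacked axis from the front to `out_axis` in every output leaf.
    ys = jax.tree_util.tree_map(lambda y: jnp.moveaxis(y, 0, out_axis), ys)
    return carry, ys


def step(carry, x):
    # Carry accumulates x; the per-step output is the carry before the update times x.
    return carry + x, carry * x


xs = jnp.ones((2, 5, 3))  # scan over the middle axis (length 5)
carry, ys = axes_scan_sketch(step, jnp.zeros((2, 3)), xs, in_axis=1, out_axis=-1)
print(carry.shape, ys.shape)  # (2, 3) (2, 3, 5)
```

Because the transposes happen outside the loop body, the body itself is exactly what one would pass to `lax.scan`; only the bookkeeping around it changes.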
I think it would make sense to have an extended version of `scan` in JAX. This version could either be a simple sequential drop-in replacement for `vmap` or a more flexible `scan` with a carry. I don't think it is necessary to make `jax.lax.scan` itself axis-aware; I would merely argue in favor of a simple wrapper. Phrased differently: people will transpose axes to make `scan` happy anyway, so why not do it for them?
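The first option, a sequential drop-in for `vmap`, could look roughly like the sketch below. It is not the gist linked further down; `smap_sketch` is a hypothetical name, and it assumes `in_axes` is an int, `None`, or a tuple with one int/`None` per positional argument, and that `out_axes` is a single int. Arguments with `in_axes=None` are treated as constants and closed over by the scan body inside the wrapper, so the caller never builds a partial function.

```python
import jax
import jax.numpy as jnp
from jax import lax


def smap_sketch(f, in_axes=0, out_axes=0):
    """Hypothetical sequential stand-in for jax.vmap built on lax.scan."""

    def wrapped(*args):
        axes = in_axes if isinstance(in_axes, (tuple, list)) else (in_axes,) * len(args)
        if len(axes) != len(args):
            raise ValueError("in_axes needs one entry per positional argument")
        mapped_idx = [i for i, ax in enumerate(axes) if ax is not None]
        if not mapped_idx:
            raise ValueError("at least one argument must be mapped")
        # Move the mapped axis of every mapped argument to the front.
        xs = tuple(
            jax.tree_util.tree_map(lambda x, ax=axes[i]: jnp.moveaxis(x, ax, 0), args[i])
            for i in mapped_idx
        )

        def body(carry, x_slices):
            # Re-assemble the argument list: one slice per mapped argument,
            # unchanged values for constants (in_axes=None).
            full = list(args)
            for i, x in zip(mapped_idx, x_slices):
                full[i] = x
            return carry, f(*full)

        _, ys = lax.scan(body, None, xs)  # no carry needed for a plain map
        # Stack the per-step outputs along `out_axes` instead of axis 0.
        return jax.tree_util.tree_map(lambda y: jnp.moveaxis(y, 0, out_axes), ys)

    return wrapped


def affine(w, x, b):
    return w @ x + b


w = jnp.arange(6.0).reshape(2, 3)
xs = jnp.ones((3, 4))  # four input vectors of length 3, stored along axis 1
b = jnp.array([1.0, -1.0])

ys_seq = smap_sketch(affine, in_axes=(None, 1, None), out_axes=1)(w, xs, b)
ys_vec = jax.vmap(affine, in_axes=(None, 1, None), out_axes=1)(w, xs, b)
print(ys_seq.shape)  # (2, 4)
```

The result matches `vmap` for this call, while only one slice of `xs` is processed per loop iteration, which is the memory trade-off that motivates using `scan` in the first place.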
I have a sequential `vmap` drop-in replacement, `smap`, at https://gist.github.com/Edenhofer/34207ad5b2b60e564e21bde9c350efd6 that I am happy to polish up if a sequential `vmap` alternative is all that is desired, but I would also be happy to have something along the lines of `flax.core.axes_scan` in JAX.

Related: `in_axes` and `out_axes` including carry