Primitive to sequentially execute a function inside a vmapped function #7199
Comments
The "sequential_vmap" concept seems like it would be easy to implement given a more general custom batching interface, which is something that @mattjj and I have considered before. Short of that, it might be worth considering offering the special case, by writing a single higher-order primitive whose batching rule is essentially a call to |
@froystig Would breaking a |
The original request is about modifying the behavior of a function under |
@froystig Apologies, I see that this was not the right place to ask that question. You did answer my question, though, so thanks.
No problem. Glad that helped!
More general "custom batching" could be interesting for use-cases like
This could definitely be a good place to start, even if only as the first step towards the general solution.
I tend to agree, like having written |
I am writing some longer algorithms which I like to vmap. I stumbled over some problems when combining vmap with the host_callback module and the lax.cond function:
host_callback.call function
I think a simple solution would be to implement a stop_vmap or sequential_vmap decorator. This decorator would define a batching rule, such that
would be the same as writing
The advantage of the decorator would be that some_fun could be used inside a much bigger vmapped function.
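To make the request concrete, here is a minimal sketch of what such a decorator might look like. It is not the implementation discussed in this thread: it assumes a recent JAX release that provides jax.custom_batching.custom_vmap, it handles only functions of a single array argument, and the names sequential_vmap_sketch and bigger_fun are hypothetical. The batching rule simply hands the batched input to lax.map, so the wrapped function runs once per batch element instead of being vectorized.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.custom_batching import custom_vmap


def sequential_vmap_sketch(fun):
    """Hypothetical decorator: under jax.vmap, apply `fun` element by element."""
    wrapped = custom_vmap(fun)

    @wrapped.def_vmap
    def rule(axis_size, in_batched, x):
        # The rule receives the input with its batch axis moved to position 0.
        (x_batched,) = in_batched
        if not x_batched:
            # Defensive: give an unbatched input an explicit leading axis.
            x = jnp.broadcast_to(x, (axis_size,) + jnp.shape(x))
        # Sequential loop over the batch axis instead of a vectorized call.
        out = lax.map(wrapped, x)
        # Every output leaf now carries the batch axis at position 0.
        out_batched = jax.tree_util.tree_map(lambda _: True, out)
        return out, out_batched

    return wrapped


@sequential_vmap_sketch
def some_fun(x):
    return jnp.sin(x) * 2.0


def bigger_fun(x):
    # some_fun is only one step inside a larger function that gets vmapped.
    return some_fun(x + 1.0) ** 2


xs = jnp.arange(4.0)
batched = jax.vmap(some_fun)(xs)   # dispatches to the custom rule
mapped = lax.map(some_fun, xs)     # the explicit sequential version
nested = jax.vmap(bigger_fun)(xs)  # only some_fun runs sequentially
```

With this sketch, jax.vmap(some_fun)(xs) and lax.map(some_fun, xs) should agree, and the rest of bigger_fun is still vectorized as usual. The trade-off is that the wrapped function gives up vectorization over the batch axis.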